In this project, our goal was to predict stock prices using a machine learning approach. To achieve this, we designed and implemented a model based on a set of carefully chosen features. These features included technical indicators such as the Relative Strength Index (RSI), Money Flow Index (MFI), Exponential Moving Averages (EMA), Simple Moving Averages (SMA), and Moving Average Convergence Divergence (MACD), as well as historical price data covering the previous 1 day, 3 days, 5 days, and 1, 2, 3, and 4 weeks. Additionally, rolling average values for high, low, open, close, adjusted close, and volume were incorporated.
import os
import time
import numpy as np
import pandas as pd
import xgboost as xgb
import matplotlib.pyplot as plt
import seaborn as sns
from xgboost import plot_importance, plot_tree
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor
from xgboost import XGBRegressor
from catboost import CatBoostRegressor
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
pd.set_option('display.max_columns', None)
# Chart drawing
import plotly as py
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Mute sklearn warnings
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
simplefilter(action='ignore', category=DeprecationWarning)
# Show charts when running kernel
#init_notebook_mode(connected=True)
# Change default background color for all visualizations
# Transparent paper background with a light plot area, captured as a reusable
# Plotly template so that every figure created below inherits it by default.
layout=go.Layout(paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(250,250,250,0.8)')
fig = go.Figure(layout=layout)
templated_fig = pio.to_templated(fig)
pio.templates['my_template'] = templated_fig.layout.template
pio.templates.default = 'my_template'
import warnings
# Silence xgboost UserWarnings specifically, then all remaining warnings —
# the notebook output would otherwise be flooded during model training.
warnings.filterwarnings("ignore", category=UserWarning, module="xgboost")
warnings.filterwarnings("ignore")
def evaluate_regression_model(y_true, y_pred):
    """
    Calculate and print evaluation metrics for a regression model.
    Parameters:
    - y_true: Actual values.
    - y_pred: Predicted values.
    Returns:
    - Dictionary containing the evaluation metrics
      (keys: 'MSE', 'RMSE', 'MAE', 'R2').
    """
    # Calculate evaluation metrics
    mse = mean_squared_error(y_true, y_pred)
    # Derive RMSE from MSE: the `squared=False` argument of
    # mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6.
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    # Print the results
    print(f'Mean Squared Error (MSE): {np.round(mse,3)}')
    print(f'Root Mean Squared Error (RMSE): {np.round(rmse,3)}')
    print(f'Mean Absolute Error (MAE): {np.round(mae,3)}')
    print(f'R-squared (R2): {np.round(r2,3)}')
    # Return results as a dictionary
    return {
        'MSE': mse,
        'RMSE': rmse,
        'MAE': mae,
        'R2': r2,
    }
def evaluate_regression_model2(y_true, y_pred):
    """
    Calculate evaluation metrics for a regression model.

    Silent variant of evaluate_regression_model: computes the same metrics
    but does not print them (useful inside loops / grid searches).
    Parameters:
    - y_true: Actual values.
    - y_pred: Predicted values.
    Returns:
    - Dictionary containing the evaluation metrics
      (keys: 'MSE', 'RMSE', 'MAE', 'R2').
    """
    mse = mean_squared_error(y_true, y_pred)
    # Derive RMSE from MSE: the `squared=False` argument of
    # mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6.
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    return {
        'MSE': mse,
        'RMSE': rmse,
        'MAE': mae,
        'R2': r2,
    }
# Returns RSI values
def rsi(df, periods = 14):
    """Return the Relative Strength Index of df['close'] as a Series.

    Uses exponentially weighted means with com = periods - 1 (Wilder-style
    smoothing); the first `periods - 1` values are NaN via min_periods.
    """
    delta = df['close'].diff()
    # Split each one-day move into its upward and downward component.
    gains = delta.clip(lower=0)
    losses = -delta.clip(upper=0)
    avg_gain = gains.ewm(com=periods - 1, adjust=True, min_periods=periods).mean()
    avg_loss = losses.ewm(com=periods - 1, adjust=True, min_periods=periods).mean()
    relative_strength = avg_gain / avg_loss
    return 100 - 100 / (1 + relative_strength)
def gain(x):
    # Sum of the positive entries of x (negative entries contribute nothing).
    return x[x > 0].sum()
def loss(x):
    # Sum of the negative entries of x (result is <= 0).
    return x[x < 0].sum()
# Calculate money flow index
# Contributed by a GitHub community member, with assistance from ChatGPT
def mfi(df, n=14):
    """Return the n-period Money Flow Index as a numpy array.

    The typical price (high + low + close) / 3 weighted by volume gives the
    raw money flow; flows are signed by whether the typical price rose or
    fell versus the prior row, then summed over a rolling window of n rows.
    """
    typical = (df['high'] + df['low'] + df['close']) / 3
    raw_flow = typical * df['volume']
    # +1 where the typical price rose versus the previous row, -1 otherwise
    # (the first row compares against NaN and therefore gets -1).
    direction = np.where(typical > typical.shift(1), 1, -1)
    signed_flow = raw_flow * direction
    # Vectorized split into in-flows and out-flows, then rolling sums.
    inflow = pd.Series(np.where(signed_flow > 0, signed_flow, 0))
    outflow = pd.Series(np.where(signed_flow < 0, -signed_flow, 0))
    gain_sum = inflow.rolling(n, min_periods=1).sum()
    loss_sum = outflow.rolling(n, min_periods=1).sum()
    return (100 - 100 / (1 + gain_sum / loss_sum)).to_numpy()
def plot_regression_accuracy(y_true, y_pred):
    """
    Create various plots to evaluate the accuracy of a regression model.

    Shows four figures sequentially (each via plt.show()): actual-vs-predicted
    scatter, residual plot, residual distribution, and predicted-vs-actual
    with a perfect-fit reference line.
    Parameters:
    - y_true: Actual values.
    - y_pred: Predicted values.
    """
    # Scatter Plot: points on the diagonal indicate accurate predictions.
    plt.scatter(y_true, y_pred)
    plt.xlabel('Actual Values')
    plt.ylabel('Predicted Values')
    plt.title('Scatter Plot of Actual vs Predicted Values')
    plt.show()
    # Residual Plot: residuals should scatter randomly around the zero line;
    # visible structure suggests systematic model error.
    residuals = y_true - y_pred
    plt.scatter(y_pred, residuals)
    plt.axhline(y=0, color='r', linestyle='--')
    plt.xlabel('Predicted Values')
    plt.ylabel('Residuals')
    plt.title('Residual Plot')
    plt.show()
    # Distribution of Residuals: ideally roughly normal, centered on zero.
    sns.histplot(residuals, kde=True)
    plt.xlabel('Residuals')
    plt.ylabel('Frequency')
    plt.title('Distribution of Residuals')
    plt.show()
    # Predicted vs Actual with a perfect-fit (y = x) reference line.
    plt.plot(y_true, y_true, linestyle='--', color='r', label='Perfect Fit')
    plt.scatter(y_true, y_pred)
    plt.xlabel('Actual Values')
    plt.ylabel('Predicted Values')
    plt.title('Predicted vs Actual Values with Perfect Fit Line')
    plt.legend()
    plt.show()
def plot_predictions(df, prediction):
    """
    Plot the full truth series with test-period predictions (top subplot)
    and a zoomed truth-vs-prediction comparison for the test period (bottom).
    Parameters:
    - df: DataFrame with 'date' and 'close_1d_next' columns (full history).
    - prediction: Predicted values aligned with the rows where date >= 2020.
    """
    # Copy the slice so assigning the prediction column cannot trigger
    # pandas' SettingWithCopyWarning or mutate the caller's frame.
    plot_test_df = df[df.date.dt.year >= 2020].copy()
    plot_test_df['prediction'] = prediction
    fig = make_subplots(rows=2, cols=1)
    fig.add_trace(go.Scatter(x=df.date, y=df.close_1d_next,
                             name='Truth',
                             marker_color='LightSkyBlue'), row=1, col=1)
    fig.add_trace(go.Scatter(x=plot_test_df.date,
                             y=plot_test_df.prediction,
                             name='Prediction',
                             marker_color='MediumPurple'), row=1, col=1)
    # Add title and Y-axis title for the first subplot
    fig.update_layout(title_text='Train Data and Test Data', title_x=0.5, title_y=0.9)
    fig.update_yaxes(title_text='Prediction', row=1, col=1)
    # Bug fix: this trace previously referenced the global `y_test`; use the
    # frame's own target column so the function is self-contained.
    fig.add_trace(go.Scatter(x=plot_test_df.date,
                             y=plot_test_df.close_1d_next,
                             name='Truth',
                             marker_color='LightSkyBlue',
                             showlegend=False), row=2, col=1)
    fig.add_trace(go.Scatter(x=plot_test_df.date,
                             y=prediction,
                             name='Prediction',
                             marker_color='MediumPurple',
                             showlegend=False), row=2, col=1)
    fig.update_yaxes(title_text='Prediction', row=2, col=1)
    fig.show()
def plot_feature_importance(model, X_train, top_features):
    """
    Plot the top feature-importance values of a fitted linear model.

    Uses the absolute value of `model.coef_`, so this only works for models
    exposing linear coefficients (LinearRegression, Ridge, Lasso, ...).
    Parameters:
    - model: Fitted estimator with a `coef_` attribute.
    - X_train: Training-feature DataFrame (column names label the bars).
    - top_features: Number of top-ranked features to display.
    Returns:
    - DataFrame of all features sorted by importance (descending).
    """
    # Rank features by coefficient magnitude.
    importance = np.abs(model.coef_)
    feature_importance_df = (
        pd.DataFrame({'Feature': X_train.columns, 'Importance': importance})
        .sort_values(by='Importance', ascending=False)
        .reset_index(drop=True)
    )
    # Slice the top rows once instead of repeating the slice expression.
    top = feature_importance_df[:top_features]
    plt.figure(figsize=(20, 6))
    plt.barh(range(len(top)), top['Importance'], align="center")
    plt.yticks(range(len(top)), labels=top['Feature'])
    plt.ylabel("Features")
    plt.xlabel("Coefficient Magnitude")
    plt.title(f"Top {top_features} Feature Importance Values")
    plt.show()
    return feature_importance_df
# Location of the prepared data files (user-specific absolute path —
# adjust when running on another machine).
out_loc = '/Users/isapocan/Desktop/LSU/data/'
# Daily stock table: OHLCV prices plus S&P 500 company metadata per row.
df = pd.read_parquet(out_loc+"stock_1d.parquet")
# Normalize column names to lowercase for consistent access below.
df.columns = df.columns.str.lower()
df.head()
| date | open | high | low | close | adj close | volume | symbol | security | gics sector | gics sub-industry | headquarters location | date added | cik | founded | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2013-01-02 | 94.190002 | 94.790001 | 93.959999 | 94.779999 | 67.895119 | 3206700.0 | MMM | 3M | Industrials | Industrial Conglomerates | Saint Paul, Minnesota | 1957-03-04 | 66740 | 1902 |
| 1 | 2013-01-03 | 94.339996 | 94.930000 | 94.129997 | 94.669998 | 67.816322 | 2704600.0 | MMM | 3M | Industrials | Industrial Conglomerates | Saint Paul, Minnesota | 1957-03-04 | 66740 | 1902 |
| 2 | 2013-01-04 | 94.790001 | 95.480003 | 94.540001 | 95.370003 | 68.317757 | 2704900.0 | MMM | 3M | Industrials | Industrial Conglomerates | Saint Paul, Minnesota | 1957-03-04 | 66740 | 1902 |
| 3 | 2013-01-07 | 95.019997 | 95.730003 | 94.760002 | 95.489998 | 68.403717 | 2745800.0 | MMM | 3M | Industrials | Industrial Conglomerates | Saint Paul, Minnesota | 1957-03-04 | 66740 | 1902 |
| 4 | 2013-01-08 | 95.169998 | 95.750000 | 95.099998 | 95.500000 | 68.410889 | 2655500.0 | MMM | 3M | Industrials | Industrial Conglomerates | Saint Paul, Minnesota | 1957-03-04 | 66740 | 1902 |
# Restrict the analysis to a single ticker: MDLZ (Mondelez International).
df = df[df['symbol']=='MDLZ']
Description: RSI helps you understand if a stock is likely to be overbought (prices too high) or oversold (prices too low). It looks at recent price changes to make this determination.

Money Flow Index (MFI):
Description: MFI considers both price and trading volume to identify if a stock is overbought or oversold. It helps gauge the strength of buying and selling pressure.
Description: EMA smoothens out price data, giving more weight to recent prices. It reacts faster to price changes compared to a Simple Moving Average (SMA), making it useful for trend analysis.
Description: SMA is a basic average of stock prices over a specific period. It provides a smoothed representation of the overall price trend, helping to identify general market direction.
Description: MACD is a trend-following momentum indicator that shows the relationship between two moving averages of a security's price. It helps identify potential trend reversals or momentum shifts.
Description: The MACD signal line is a nine-day EMA of the MACD. It is used to generate trading signals. When the MACD crosses above the signal line, it might be a signal to buy, and when it crosses below, it might be a signal to sell.
# Trend features. Each series is shifted by one row so a given day's feature
# only uses information available before that day (no target leakage).
# NOTE(review): ewm(9) sets the first positional argument `com=9`, not
# span=9 — presumably intentional, but confirm the decay was meant as com.
df['ema_9'] = df['close'].ewm(9).mean().shift()
df['sma_5'] = df['close'].rolling(5).mean().shift()
df['sma_10'] = df['close'].rolling(10).mean().shift()
df['sma_15'] = df['close'].rolling(15).mean().shift()
df['sma_30'] = df['close'].rolling(30).mean().shift()
# Quick structural check of the enriched frame (dtypes, non-null counts).
df.info()
<class 'pandas.core.frame.DataFrame'> Int64Index: 2733 entries, 852843 to 855575 Data columns (total 20 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 date 2733 non-null datetime64[ns] 1 open 2733 non-null float64 2 high 2733 non-null float64 3 low 2733 non-null float64 4 close 2733 non-null float64 5 adj close 2733 non-null float64 6 volume 2733 non-null float64 7 symbol 2733 non-null object 8 security 2733 non-null object 9 gics sector 2733 non-null object 10 gics sub-industry 2733 non-null object 11 headquarters location 2733 non-null object 12 date added 2733 non-null object 13 cik 2733 non-null int64 14 founded 2733 non-null object 15 ema_9 2732 non-null float64 16 sma_5 2728 non-null float64 17 sma_10 2723 non-null float64 18 sma_15 2718 non-null float64 19 sma_30 2703 non-null float64 dtypes: datetime64[ns](1), float64(11), int64(1), object(7) memory usage: 448.4+ KB
# 14-period momentum indicators computed by the helper functions above.
df['rsi'] = rsi(df) #.fillna(0)
df['mfi'] = mfi(df, 14)
# Preview the engineered indicator columns next to the close price.
df[['date','close','ema_9','sma_5','sma_10','sma_15','sma_30','rsi','mfi']]
| date | close | ema_9 | sma_5 | sma_10 | sma_15 | sma_30 | rsi | mfi | |
|---|---|---|---|---|---|---|---|---|---|
| 852843 | 2013-01-02 | 26.670000 | NaN | NaN | NaN | NaN | NaN | NaN | 0.000000 |
| 852844 | 2013-01-03 | 26.639999 | 26.670000 | NaN | NaN | NaN | NaN | NaN | 33.904295 |
| 852845 | 2013-01-04 | 26.740000 | 26.654210 | NaN | NaN | NaN | NaN | NaN | 48.695375 |
| 852846 | 2013-01-07 | 26.660000 | 26.685867 | NaN | NaN | NaN | NaN | NaN | 39.919745 |
| 852847 | 2013-01-08 | 26.680000 | 26.678345 | NaN | NaN | NaN | NaN | NaN | 55.233142 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 855571 | 2023-11-02 | 67.970001 | 65.538988 | 65.934001 | 65.321001 | 64.406001 | 65.993000 | 60.257764 | 89.207420 |
| 855572 | 2023-11-03 | 68.820000 | 65.782090 | 66.398001 | 65.697001 | 64.868001 | 65.901667 | 63.726091 | 89.458580 |
| 855573 | 2023-11-06 | 68.239998 | 66.085881 | 67.160001 | 66.169001 | 65.354001 | 65.848000 | 59.885606 | 83.710782 |
| 855574 | 2023-11-07 | 68.489998 | 66.301292 | 67.612000 | 66.594001 | 65.728667 | 65.799000 | 60.977252 | 75.937617 |
| 855575 | 2023-11-08 | 69.019997 | 66.520163 | 68.067999 | 66.888000 | 66.058667 | 65.729667 | 63.259914 | 75.566164 |
2733 rows × 9 columns
# Check how redundant the two momentum indicators are; the output below
# shows a moderate (~0.70) correlation — related but not interchangeable.
df[['rsi','mfi']].corr()
| rsi | mfi | |
|---|---|---|
| rsi | 1.000000 | 0.698958 |
| mfi | 0.698958 | 1.000000 |
# MACD line: difference between the 12- and 26-period EMAs of the close.
# (ewm(...).mean() already returns a Series, so the pd.Series(...) wrappers
# in the original were redundant and have been removed.)
EMA_12 = df['close'].ewm(span=12, min_periods=12).mean()
EMA_26 = df['close'].ewm(span=26, min_periods=26).mean()
df['macd'] = EMA_12 - EMA_26
# Signal line: 9-period EMA of the MACD line, used to generate crossovers.
df['macd_signal'] = df.macd.ewm(span=9, min_periods=9).mean()
# Preview the first rows where both series are defined (past the warm-up).
df[(~df['macd'].isna()) & (~df['macd_signal'].isna())][['macd','macd_signal']].head()
| macd | macd_signal | |
|---|---|---|
| 852876 | -0.147786 | -0.050945 |
| 852877 | -0.175230 | -0.078792 |
| 852878 | -0.198438 | -0.104970 |
| 852879 | -0.235462 | -0.132994 |
| 852880 | -0.226841 | -0.152855 |
### predict next day
# Target variable: the NEXT trading day's close (shift(-1) looks one row ahead).
df['close_1d_next'] = df['close'].shift(-1)
# Sanity check: close_1d_next on each row equals close on the following row.
df[['date','close','close_1d_next']].head()
| date | close | close_1d_next | |
|---|---|---|---|
| 852843 | 2013-01-02 | 26.670000 | 26.639999 |
| 852844 | 2013-01-03 | 26.639999 | 26.740000 |
| 852845 | 2013-01-04 | 26.740000 | 26.660000 |
| 852846 | 2013-01-07 | 26.660000 | 26.680000 |
| 852847 | 2013-01-08 | 26.680000 | 27.049999 |
# Lagged price/volume features: each (column, lag) pair becomes one feature.
# Lags are labeled in days (d) and weeks (w), e.g. close_2w_ago = the close
# 14 trading rows back. The loops replace ~70 hand-written assignments and
# create the columns in exactly the same order as the original code.
_lag_specs = [('1d', 1), ('3d', 3), ('5d', 5), ('1w', 7), ('2w', 14), ('3w', 21), ('4w', 28)]
for _col, _name in [('close', 'close'), ('adj close', 'adj_close'), ('open', 'open'),
                    ('high', 'high'), ('low', 'low'), ('volume', 'volume')]:
    for _label, _n in _lag_specs:
        df[f'{_name}_{_label}_ago'] = df[_col].shift(_n)
# Rolling-mean features over several window sizes. Note these include the
# current row (no .shift()), matching the original hand-written version.
for _col, _name in [('open', 'open'), ('high', 'high'), ('low', 'low'),
                    ('volume', 'volume'), ('adj close', 'adj_close')]:
    for _w in (3, 5, 7, 10, 15, 30):
        df[f'{_name}_{_w}d_avg'] = df[_col].rolling(window=_w).mean()
# Drop the warm-up rows that contain NaNs from shifting/rolling, re-index.
df = df.dropna().reset_index(drop=True)
df.head()
| date | open | high | low | close | adj close | volume | symbol | security | gics sector | gics sub-industry | headquarters location | date added | cik | founded | ema_9 | sma_5 | sma_10 | sma_15 | sma_30 | rsi | mfi | macd | macd_signal | close_1d_next | close_1d_ago | close_3d_ago | close_5d_ago | close_1w_ago | close_2w_ago | close_3w_ago | close_4w_ago | adj_close_1d_ago | adj_close_3d_ago | adj_close_5d_ago | adj_close_1w_ago | adj_close_2w_ago | adj_close_3w_ago | adj_close_4w_ago | open_1d_ago | open_3d_ago | open_5d_ago | open_1w_ago | open_2w_ago | open_3w_ago | open_4w_ago | high_1d_ago | high_3d_ago | high_5d_ago | high_1w_ago | high_2w_ago | high_3w_ago | high_4w_ago | low_1d_ago | low_3d_ago | low_5d_ago | low_1w_ago | low_2w_ago | low_3w_ago | low_4w_ago | volume_1d_ago | volume_3d_ago | volume_5d_ago | volume_1w_ago | volume_2w_ago | volume_3w_ago | volume_4w_ago | open_3d_avg | open_5d_avg | open_7d_avg | open_10d_avg | open_15d_avg | open_30d_avg | high_3d_avg | high_5d_avg | high_7d_avg | high_10d_avg | high_15d_avg | high_30d_avg | low_3d_avg | low_5d_avg | low_7d_avg | low_10d_avg | low_15d_avg | low_30d_avg | volume_3d_avg | volume_5d_avg | volume_7d_avg | volume_10d_avg | volume_15d_avg | volume_30d_avg | adj_close_3d_avg | adj_close_5d_avg | adj_close_7d_avg | adj_close_10d_avg | adj_close_15d_avg | adj_close_30d_avg | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2013-02-20 | 27.070000 | 27.150000 | 26.950001 | 27.030001 | 21.735399 | 17057200.0 | MDLZ | Mondelez International | Consumer Staples | Packaged Foods & Meats | Chicago, Illinois | 2012-10-02 | 1103982 | 2012 | 27.499536 | 27.136 | 27.518 | 27.642000 | 27.589000 | 41.633625 | 53.176274 | -0.147786 | -0.050945 | 26.820000 | 26.959999 | 26.570000 | 27.680000 | 27.76 | 27.730000 | 28.080000 | 27.049999 | 21.679117 | 21.365499 | 22.258080 | 22.322405 | 22.298285 | 22.579723 | 21.751484 | 26.750000 | 26.690001 | 27.700001 | 27.799999 | 27.830000 | 27.969999 | 26.790001 | 27.190001 | 27.020000 | 27.830000 | 28.100000 | 27.980000 | 28.100000 | 27.080000 | 26.750000 | 26.450001 | 27.270000 | 27.750000 | 27.67 | 27.820000 | 26.68 | 18297500.0 | 37728900.0 | 14931000.0 | 11159200.0 | 5800400.0 | 15906900.0 | 11671400.0 | 26.886667 | 27.018 | 27.217143 | 27.386 | 27.553333 | 27.536667 | 27.136667 | 27.248000 | 27.410000 | 27.618 | 27.779333 | 27.754667 | 26.766667 | 26.842 | 27.015714 | 27.224 | 27.411333 | 27.382333 | 1.904973e+07 | 21756140.0 | 1.907480e+07 | 17005580.0 | 1.419575e+07 | 1.352419e+07 | 21.633545 | 21.716101 | 21.878994 | 22.053831 | 22.184635 | 22.194819 |
| 1 | 2013-02-21 | 26.990000 | 27.049999 | 26.639999 | 26.820000 | 21.566534 | 16936600.0 | MDLZ | Mondelez International | Consumer Staples | Packaged Foods & Meats | Chicago, Illinois | 2012-10-02 | 1103982 | 2012 | 27.451239 | 27.006 | 27.426 | 27.588667 | 27.601333 | 38.257648 | 47.431888 | -0.175230 | -0.078792 | 26.770000 | 27.030001 | 26.719999 | 27.750000 | 27.75 | 27.790001 | 27.559999 | 27.309999 | 21.735399 | 21.486118 | 22.314371 | 22.314371 | 22.346525 | 22.161583 | 21.960548 | 27.070000 | 26.840000 | 27.740000 | 27.730000 | 27.650000 | 27.730000 | 27.129999 | 27.150000 | 27.070000 | 27.809999 | 27.799999 | 27.950001 | 28.040001 | 27.340000 | 26.950001 | 26.600000 | 27.459999 | 27.629999 | 27.65 | 27.299999 | 27.09 | 17057200.0 | 21794500.0 | 13902600.0 | 9811900.0 | 7541300.0 | 18213200.0 | 16348500.0 | 26.936666 | 26.868 | 27.111429 | 27.295 | 27.497333 | 27.552333 | 27.130000 | 27.096000 | 27.302857 | 27.510 | 27.717333 | 27.759000 | 26.780000 | 26.678 | 26.874286 | 27.108 | 27.342667 | 27.388333 | 1.743043e+07 | 22362940.0 | 2.009261e+07 | 17608410.0 | 1.493817e+07 | 1.361005e+07 | 21.660350 | 21.566534 | 21.772160 | 21.958945 | 22.135851 | 22.198572 |
| 2 | 2013-02-22 | 26.889999 | 27.129999 | 26.730000 | 26.770000 | 21.526327 | 16664800.0 | MDLZ | Mondelez International | Consumer Staples | Packaged Foods & Meats | Chicago, Illinois | 2012-10-02 | 1103982 | 2012 | 27.386494 | 26.820 | 27.308 | 27.528000 | 27.606000 | 37.478423 | 48.958416 | -0.198438 | -0.104970 | 26.490000 | 26.820000 | 26.959999 | 26.570000 | 27.68 | 28.219999 | 27.790001 | 27.420000 | 21.566534 | 21.679117 | 21.365499 | 22.258080 | 22.692308 | 22.346525 | 22.049007 | 26.990000 | 26.750000 | 26.690001 | 27.700001 | 28.000000 | 27.500000 | 27.350000 | 27.049999 | 27.190001 | 27.020000 | 27.830000 | 28.320000 | 27.889999 | 27.540001 | 26.639999 | 26.750000 | 26.450001 | 27.270000 | 27.93 | 27.350000 | 27.25 | 16936600.0 | 18297500.0 | 37728900.0 | 14931000.0 | 9623100.0 | 15212300.0 | 10162600.0 | 26.983333 | 26.908 | 26.995714 | 27.220 | 27.446667 | 27.555667 | 27.109999 | 27.118000 | 27.202857 | 27.415 | 27.662666 | 27.760667 | 26.773333 | 26.734 | 26.797143 | 27.023 | 27.281333 | 27.390000 | 1.688620e+07 | 18150120.0 | 2.034030e+07 | 17828420.0 | 1.554640e+07 | 1.377650e+07 | 21.609420 | 21.598699 | 21.667624 | 21.856822 | 22.081171 | 22.191067 |
| 3 | 2013-02-25 | 26.790001 | 27.080000 | 26.480000 | 26.490000 | 21.301172 | 15527100.0 | MDLZ | Mondelez International | Consumer Staples | Packaged Foods & Meats | Chicago, Illinois | 2012-10-02 | 1103982 | 2012 | 27.323424 | 26.860 | 27.181 | 27.460000 | 27.596667 | 33.378362 | 47.675126 | -0.235462 | -0.132994 | 26.950001 | 26.770000 | 27.030001 | 26.719999 | 27.75 | 27.879999 | 27.830000 | 27.480000 | 21.526327 | 21.735399 | 21.486118 | 22.314371 | 22.418896 | 22.378695 | 22.097254 | 26.889999 | 27.070000 | 26.840000 | 27.740000 | 28.010000 | 27.930000 | 27.459999 | 27.129999 | 27.150000 | 27.070000 | 27.809999 | 28.150000 | 28.030001 | 27.520000 | 26.730000 | 26.950001 | 26.600000 | 27.459999 | 27.83 | 27.639999 | 27.17 | 16664800.0 | 17057200.0 | 21794500.0 | 13902600.0 | 8954300.0 | 14444500.0 | 8688200.0 | 26.890000 | 26.898 | 26.860000 | 27.119 | 27.366000 | 27.544333 | 27.086666 | 27.120000 | 27.098571 | 27.313 | 27.580000 | 27.752000 | 26.616666 | 26.710 | 26.657143 | 26.896 | 27.184667 | 27.369667 | 1.637617e+07 | 16896640.0 | 2.057237e+07 | 18265210.0 | 1.594000e+07 | 1.374912e+07 | 21.464678 | 21.561710 | 21.522881 | 21.754699 | 21.988429 | 22.169087 |
| 4 | 2013-02-26 | 26.530001 | 26.980000 | 26.510000 | 26.950001 | 21.671074 | 13702900.0 | MDLZ | Mondelez International | Consumer Staples | Packaged Foods & Meats | Chicago, Illinois | 2012-10-02 | 1103982 | 2012 | 27.238357 | 26.814 | 27.054 | 27.344667 | 27.569333 | 44.181951 | 48.178912 | -0.226841 | -0.152855 | 27.570000 | 26.490000 | 26.820000 | 26.959999 | 26.57 | 27.950001 | 27.780001 | 27.709999 | 21.301172 | 21.566534 | 21.679117 | 21.365499 | 22.475189 | 22.338484 | 22.282200 | 26.790001 | 26.990000 | 26.750000 | 26.690001 | 27.950001 | 27.830000 | 27.580000 | 27.080000 | 27.049999 | 27.190001 | 27.020000 | 28.110001 | 27.889999 | 27.740000 | 26.480000 | 26.639999 | 26.750000 | 26.450001 | 27.85 | 27.690001 | 27.34 | 15527100.0 | 16936600.0 | 18297500.0 | 37728900.0 | 10961400.0 | 12066800.0 | 9863200.0 | 26.736667 | 26.854 | 26.837143 | 26.999 | 27.267333 | 27.517000 | 27.063333 | 27.077999 | 27.092857 | 27.231 | 27.502000 | 27.733333 | 26.573333 | 26.662 | 26.665714 | 26.784 | 27.096667 | 27.345000 | 1.529827e+07 | 15977720.0 | 1.714009e+07 | 18654310.0 | 1.625657e+07 | 1.386713e+07 | 21.499524 | 21.560101 | 21.566535 | 21.690369 | 21.938574 | 22.156490 |
# # Calculate the index for the 70-30 split
# split_index = int(0.7 * len(df))
# # Split the DataFrame into training and testing sets
# train_df = df.iloc[:split_index]
# test_df = df.iloc[split_index:]
# Split the DataFrame into training and testing sets.
# Chronological split: everything before 2020 trains, 2020 onward tests —
# a time-based split avoids the look-ahead leakage a random split would cause.
train_df = df[df.date.dt.year<2020]
test_df = df[df.date.dt.year>=2020]
print(f"Train days: {len(train_df)}, Test days: {len(test_df)}")
# Visualize the train/test boundary on the target series.
fig = go.Figure()
fig.add_trace(go.Scatter(x=train_df.date, y=train_df.close_1d_next, name='Training'))
fig.add_trace(go.Scatter(x=test_df.date, y=test_df.close_1d_next, name='Test'))
fig.show()
Train days: 1729, Test days: 970
# Columns that are raw prices, identifiers, or company metadata — excluded
# from the feature matrix so only engineered features remain.
drop_cols1 = ['date','open','high','low','close','adj close','volume','symbol','security',
              'gics sector','gics sub-industry','headquarters location','date added','cik','founded']
# Use the `columns=` keyword: passing axis positionally as the second
# argument (`.drop(cols, 1)`) was deprecated and removed in pandas 2.0.
train_df = train_df.drop(columns=drop_cols1)
test_df = test_df.drop(columns=drop_cols1)
# target column is next day's close price
y_train = train_df['close_1d_next'].copy()
X_train = train_df.drop(columns=['close_1d_next'])
# target column is next day's close price
y_test = test_df['close_1d_next'].copy()
X_test = test_df.drop(columns=['close_1d_next'])
# Standardize features; fit on the training set only to avoid test leakage.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
X_train.shape, X_train_scaled.shape, X_test.shape, X_test_scaled.shape,
((1729, 81), (1729, 81), (970, 81), (970, 81))
def train_and_evaluate_models(X_train_scaled, y_train, X_test_scaled, y_test):
    """
    Train and evaluate multiple regression models on pre-split data.
    Parameters:
    - X_train_scaled: Training features (array-like or DataFrame).
    - y_train: Training target values.
    - X_test_scaled: Test features.
    - y_test: Test target values.
    Returns:
    - DataFrame with one row per model (MSE, MAE, R2 Score, training time),
      sorted by R2 Score descending (best model first).
    """
    # Initialize the models (all with default hyperparameters).
    models = {
        'Linear Regression': LinearRegression(),
        'Ridge Regression': Ridge(),
        'Lasso Regression': Lasso(),
        'Elastic Net': ElasticNet(),
        'SVR': SVR(),
        'K-Neighbors Regressor': KNeighborsRegressor(),
        'Decision Tree': DecisionTreeRegressor(),
        'Random Forest': RandomForestRegressor(),
        'Gradient Boosting': GradientBoostingRegressor(),
        'AdaBoost': AdaBoostRegressor(),
        'XGBoost': XGBRegressor(),
        'CatBoost': CatBoostRegressor()
    }
    # Collect one metrics record per model. DataFrame.append was removed in
    # pandas 2.0, so build the frame once from a list of dicts at the end.
    records = []
    for model_name, model in models.items():
        # Train the model and time the fit.
        start_time = time.time()
        model.fit(X_train_scaled, y_train)
        training_time = time.time() - start_time
        # Make predictions on the held-out test set.
        y_pred = model.predict(X_test_scaled)
        # Evaluate the model.
        records.append({
            'Model': model_name,
            'Mean Squared Error': mean_squared_error(y_test, y_pred),
            'Mean Absolute Error': mean_absolute_error(y_test, y_pred),
            'R2 Score': r2_score(y_test, y_pred),
            'Training Time (s)': training_time,
        })
    metrics_df = pd.DataFrame(records).sort_values(by=['R2 Score'], ascending=False)
    return metrics_df
df_compare = train_and_evaluate_models(X_train,y_train,X_test,y_test)
Learning rate set to 0.044643 0: learn: 6.0670144 total: 64.1ms remaining: 1m 4s 1: learn: 5.8228277 total: 67.3ms remaining: 33.6s 2: learn: 5.5825438 total: 70.6ms remaining: 23.5s 3: learn: 5.3589845 total: 74.2ms remaining: 18.5s 4: learn: 5.1400669 total: 78.2ms remaining: 15.6s 5: learn: 4.9338718 total: 81.7ms remaining: 13.5s 6: learn: 4.7348036 total: 85ms remaining: 12.1s 7: learn: 4.5462470 total: 88.4ms remaining: 11s 8: learn: 4.3663666 total: 92ms remaining: 10.1s 9: learn: 4.1940087 total: 95.6ms remaining: 9.47s 10: learn: 4.0326464 total: 99.2ms remaining: 8.92s 11: learn: 3.8728104 total: 103ms remaining: 8.46s 12: learn: 3.7241421 total: 106ms remaining: 8.06s 13: learn: 3.5819820 total: 110ms remaining: 7.75s 14: learn: 3.4443416 total: 113ms remaining: 7.45s 15: learn: 3.3131010 total: 117ms remaining: 7.2s 16: learn: 3.1862951 total: 121ms remaining: 6.97s 17: learn: 3.0636007 total: 124ms remaining: 6.77s 18: learn: 2.9525545 total: 128ms remaining: 6.6s 19: learn: 2.8449392 total: 131ms remaining: 6.44s 20: learn: 2.7377198 total: 135ms remaining: 6.29s 21: learn: 2.6359381 total: 139ms remaining: 6.16s 22: learn: 2.5393906 total: 142ms remaining: 6.04s 23: learn: 2.4447919 total: 146ms remaining: 5.92s 24: learn: 2.3592925 total: 149ms remaining: 5.81s 25: learn: 2.2778166 total: 153ms remaining: 5.73s 26: learn: 2.1955454 total: 157ms remaining: 5.65s 27: learn: 2.1181936 total: 160ms remaining: 5.56s 28: learn: 2.0456841 total: 163ms remaining: 5.47s 29: learn: 1.9773187 total: 167ms remaining: 5.4s 30: learn: 1.9140404 total: 171ms remaining: 5.33s 31: learn: 1.8503968 total: 174ms remaining: 5.26s 32: learn: 1.7872718 total: 178ms remaining: 5.2s 33: learn: 1.7298101 total: 181ms remaining: 5.13s 34: learn: 1.6773427 total: 184ms remaining: 5.07s 35: learn: 1.6263746 total: 187ms remaining: 5.02s 36: learn: 1.5799050 total: 191ms remaining: 4.98s 37: learn: 1.5350898 total: 195ms remaining: 4.95s 38: learn: 1.4896754 total: 199ms 
remaining: 4.89s 39: learn: 1.4440908 total: 202ms remaining: 4.84s 40: learn: 1.4038854 total: 205ms remaining: 4.81s 41: learn: 1.3654869 total: 210ms remaining: 4.79s 42: learn: 1.3304415 total: 214ms remaining: 4.76s 43: learn: 1.2940654 total: 218ms remaining: 4.73s 44: learn: 1.2608561 total: 221ms remaining: 4.7s 45: learn: 1.2267358 total: 225ms remaining: 4.66s 46: learn: 1.1934982 total: 228ms remaining: 4.62s 47: learn: 1.1672852 total: 231ms remaining: 4.58s 48: learn: 1.1375280 total: 234ms remaining: 4.55s 49: learn: 1.1138729 total: 238ms remaining: 4.52s 50: learn: 1.0899582 total: 241ms remaining: 4.49s 51: learn: 1.0649512 total: 244ms remaining: 4.46s 52: learn: 1.0415461 total: 248ms remaining: 4.42s 53: learn: 1.0206758 total: 251ms remaining: 4.4s 54: learn: 0.9991569 total: 254ms remaining: 4.37s 55: learn: 0.9795833 total: 258ms remaining: 4.35s 56: learn: 0.9603286 total: 261ms remaining: 4.32s 57: learn: 0.9430651 total: 264ms remaining: 4.29s 58: learn: 0.9257478 total: 267ms remaining: 4.26s 59: learn: 0.9087157 total: 271ms remaining: 4.25s 60: learn: 0.8934800 total: 274ms remaining: 4.22s 61: learn: 0.8810773 total: 277ms remaining: 4.19s 62: learn: 0.8666743 total: 280ms remaining: 4.17s 63: learn: 0.8534041 total: 284ms remaining: 4.15s 64: learn: 0.8414198 total: 287ms remaining: 4.13s 65: learn: 0.8286451 total: 290ms remaining: 4.11s 66: learn: 0.8182783 total: 293ms remaining: 4.08s 67: learn: 0.8072075 total: 296ms remaining: 4.06s 68: learn: 0.7983925 total: 299ms remaining: 4.04s 69: learn: 0.7894196 total: 303ms remaining: 4.02s 70: learn: 0.7802497 total: 306ms remaining: 4s 71: learn: 0.7722262 total: 309ms remaining: 3.98s 72: learn: 0.7636665 total: 312ms remaining: 3.97s 73: learn: 0.7559532 total: 316ms remaining: 3.95s 74: learn: 0.7484620 total: 320ms remaining: 3.95s 75: learn: 0.7409614 total: 324ms remaining: 3.94s 76: learn: 0.7340694 total: 327ms remaining: 3.92s 77: learn: 0.7272183 total: 330ms remaining: 
3.91s 78: learn: 0.7206823 total: 334ms remaining: 3.89s 79: learn: 0.7145614 total: 337ms remaining: 3.87s 80: learn: 0.7090921 total: 340ms remaining: 3.86s 81: learn: 0.7039718 total: 343ms remaining: 3.84s 82: learn: 0.6989219 total: 346ms remaining: 3.83s 83: learn: 0.6946642 total: 350ms remaining: 3.81s 84: learn: 0.6902238 total: 353ms remaining: 3.8s 85: learn: 0.6850230 total: 356ms remaining: 3.78s 86: learn: 0.6809441 total: 359ms remaining: 3.77s 87: learn: 0.6778489 total: 362ms remaining: 3.76s 88: learn: 0.6740331 total: 366ms remaining: 3.74s 89: learn: 0.6704494 total: 369ms remaining: 3.73s 90: learn: 0.6661259 total: 372ms remaining: 3.71s 91: learn: 0.6626042 total: 375ms remaining: 3.7s 92: learn: 0.6590292 total: 378ms remaining: 3.69s 93: learn: 0.6554413 total: 381ms remaining: 3.67s 94: learn: 0.6526417 total: 384ms remaining: 3.66s 95: learn: 0.6490439 total: 387ms remaining: 3.65s 96: learn: 0.6467224 total: 391ms remaining: 3.64s 97: learn: 0.6432002 total: 395ms remaining: 3.63s 98: learn: 0.6412776 total: 398ms remaining: 3.62s 99: learn: 0.6389192 total: 402ms remaining: 3.62s 100: learn: 0.6359101 total: 406ms remaining: 3.62s 101: learn: 0.6322537 total: 411ms remaining: 3.62s 102: learn: 0.6295722 total: 415ms remaining: 3.62s 103: learn: 0.6266020 total: 419ms remaining: 3.61s 104: learn: 0.6242294 total: 423ms remaining: 3.6s 105: learn: 0.6217499 total: 426ms remaining: 3.59s 106: learn: 0.6204236 total: 429ms remaining: 3.58s 107: learn: 0.6179910 total: 432ms remaining: 3.57s 108: learn: 0.6154061 total: 436ms remaining: 3.56s 109: learn: 0.6137596 total: 439ms remaining: 3.55s 110: learn: 0.6117283 total: 442ms remaining: 3.54s 111: learn: 0.6098140 total: 445ms remaining: 3.53s 112: learn: 0.6075496 total: 448ms remaining: 3.52s 113: learn: 0.6058904 total: 451ms remaining: 3.51s 114: learn: 0.6038122 total: 454ms remaining: 3.5s 115: learn: 0.6021583 total: 458ms remaining: 3.49s 116: learn: 0.6007633 total: 462ms 
remaining: 3.48s 117: learn: 0.5989211 total: 465ms remaining: 3.47s 118: learn: 0.5972269 total: 468ms remaining: 3.46s 119: learn: 0.5953529 total: 471ms remaining: 3.46s 120: learn: 0.5937351 total: 474ms remaining: 3.45s 121: learn: 0.5917424 total: 478ms remaining: 3.44s 122: learn: 0.5902046 total: 481ms remaining: 3.43s 123: learn: 0.5883267 total: 484ms remaining: 3.42s 124: learn: 0.5870216 total: 487ms remaining: 3.41s 125: learn: 0.5853003 total: 490ms remaining: 3.4s 126: learn: 0.5843247 total: 493ms remaining: 3.39s 127: learn: 0.5823154 total: 497ms remaining: 3.38s 128: learn: 0.5809541 total: 500ms remaining: 3.37s 129: learn: 0.5789569 total: 503ms remaining: 3.37s 130: learn: 0.5773752 total: 506ms remaining: 3.36s 131: learn: 0.5752670 total: 510ms remaining: 3.35s 132: learn: 0.5737710 total: 513ms remaining: 3.34s 133: learn: 0.5722145 total: 516ms remaining: 3.33s 134: learn: 0.5706925 total: 519ms remaining: 3.33s 135: learn: 0.5696125 total: 522ms remaining: 3.32s 136: learn: 0.5684057 total: 525ms remaining: 3.31s 137: learn: 0.5668800 total: 528ms remaining: 3.3s 138: learn: 0.5652156 total: 531ms remaining: 3.29s 139: learn: 0.5641754 total: 534ms remaining: 3.28s 140: learn: 0.5627745 total: 538ms remaining: 3.27s 141: learn: 0.5621379 total: 541ms remaining: 3.27s 142: learn: 0.5609717 total: 544ms remaining: 3.26s 143: learn: 0.5595847 total: 547ms remaining: 3.25s 144: learn: 0.5582880 total: 550ms remaining: 3.24s 145: learn: 0.5573794 total: 553ms remaining: 3.23s 146: learn: 0.5561539 total: 556ms remaining: 3.23s 147: learn: 0.5548667 total: 559ms remaining: 3.22s 148: learn: 0.5533491 total: 562ms remaining: 3.21s 149: learn: 0.5519277 total: 566ms remaining: 3.21s 150: learn: 0.5510423 total: 569ms remaining: 3.2s 151: learn: 0.5500217 total: 572ms remaining: 3.19s 152: learn: 0.5485391 total: 575ms remaining: 3.18s 153: learn: 0.5472895 total: 578ms remaining: 3.18s 154: learn: 0.5464558 total: 581ms remaining: 3.17s 155: 
learn: 0.5456874 total: 584ms remaining: 3.16s 156: learn: 0.5443351 total: 588ms remaining: 3.16s 157: learn: 0.5431722 total: 591ms remaining: 3.15s 158: learn: 0.5420827 total: 594ms remaining: 3.14s 159: learn: 0.5404464 total: 598ms remaining: 3.14s 160: learn: 0.5394812 total: 602ms remaining: 3.14s 161: learn: 0.5383878 total: 607ms remaining: 3.14s 162: learn: 0.5371991 total: 610ms remaining: 3.13s 163: learn: 0.5360893 total: 614ms remaining: 3.13s 164: learn: 0.5348795 total: 618ms remaining: 3.13s 165: learn: 0.5339208 total: 621ms remaining: 3.12s 166: learn: 0.5325131 total: 624ms remaining: 3.11s 167: learn: 0.5314276 total: 628ms remaining: 3.11s 168: learn: 0.5303847 total: 631ms remaining: 3.1s 169: learn: 0.5294339 total: 634ms remaining: 3.09s 170: learn: 0.5284374 total: 637ms remaining: 3.09s 171: learn: 0.5275295 total: 640ms remaining: 3.08s 172: learn: 0.5267194 total: 643ms remaining: 3.08s 173: learn: 0.5256977 total: 647ms remaining: 3.07s 174: learn: 0.5240502 total: 650ms remaining: 3.06s 175: learn: 0.5231121 total: 653ms remaining: 3.06s 176: learn: 0.5219285 total: 656ms remaining: 3.05s 177: learn: 0.5212026 total: 659ms remaining: 3.04s 178: learn: 0.5201152 total: 662ms remaining: 3.04s 179: learn: 0.5191207 total: 665ms remaining: 3.03s 180: learn: 0.5184563 total: 669ms remaining: 3.02s 181: learn: 0.5175407 total: 672ms remaining: 3.02s 182: learn: 0.5164491 total: 674ms remaining: 3.01s 183: learn: 0.5152700 total: 678ms remaining: 3s 184: learn: 0.5144159 total: 681ms remaining: 3s 185: learn: 0.5133076 total: 684ms remaining: 2.99s 186: learn: 0.5120716 total: 688ms remaining: 2.99s 187: learn: 0.5112054 total: 691ms remaining: 2.98s 188: learn: 0.5100507 total: 694ms remaining: 2.98s 189: learn: 0.5090646 total: 697ms remaining: 2.97s 190: learn: 0.5079116 total: 700ms remaining: 2.97s 191: learn: 0.5068761 total: 704ms remaining: 2.96s 192: learn: 0.5059156 total: 707ms remaining: 2.96s 193: learn: 0.5046199 total: 710ms 
remaining: 2.95s 194: learn: 0.5036700 total: 713ms remaining: 2.94s 195: learn: 0.5030163 total: 716ms remaining: 2.94s 196: learn: 0.5021730 total: 719ms remaining: 2.93s 197: learn: 0.5013725 total: 723ms remaining: 2.93s 198: learn: 0.5003361 total: 726ms remaining: 2.92s 199: learn: 0.4992527 total: 729ms remaining: 2.92s 200: learn: 0.4982242 total: 732ms remaining: 2.91s 201: learn: 0.4969857 total: 735ms remaining: 2.9s 202: learn: 0.4964471 total: 738ms remaining: 2.9s 203: learn: 0.4955578 total: 741ms remaining: 2.89s 204: learn: 0.4948073 total: 744ms remaining: 2.89s 205: learn: 0.4938657 total: 747ms remaining: 2.88s 206: learn: 0.4927562 total: 751ms remaining: 2.88s 207: learn: 0.4922795 total: 754ms remaining: 2.87s 208: learn: 0.4912653 total: 757ms remaining: 2.86s 209: learn: 0.4901704 total: 760ms remaining: 2.86s 210: learn: 0.4890394 total: 763ms remaining: 2.85s 211: learn: 0.4880374 total: 767ms remaining: 2.85s 212: learn: 0.4869604 total: 770ms remaining: 2.84s 213: learn: 0.4857559 total: 773ms remaining: 2.84s 214: learn: 0.4849108 total: 776ms remaining: 2.83s 215: learn: 0.4840655 total: 779ms remaining: 2.83s 216: learn: 0.4832522 total: 783ms remaining: 2.82s 217: learn: 0.4823870 total: 786ms remaining: 2.82s 218: learn: 0.4811295 total: 789ms remaining: 2.81s 219: learn: 0.4800552 total: 792ms remaining: 2.81s 220: learn: 0.4795163 total: 796ms remaining: 2.81s 221: learn: 0.4786248 total: 800ms remaining: 2.8s 222: learn: 0.4778513 total: 803ms remaining: 2.8s 223: learn: 0.4771565 total: 806ms remaining: 2.79s 224: learn: 0.4763662 total: 809ms remaining: 2.79s 225: learn: 0.4752713 total: 813ms remaining: 2.78s 226: learn: 0.4744328 total: 816ms remaining: 2.78s 227: learn: 0.4736467 total: 819ms remaining: 2.77s 228: learn: 0.4726439 total: 822ms remaining: 2.77s 229: learn: 0.4720170 total: 825ms remaining: 2.76s 230: learn: 0.4709224 total: 829ms remaining: 2.76s 231: learn: 0.4703208 total: 832ms remaining: 2.75s 232: 
learn: 0.4694894 total: 835ms remaining: 2.75s 233: learn: 0.4687498 total: 838ms remaining: 2.74s 234: learn: 0.4679239 total: 841ms remaining: 2.74s 235: learn: 0.4670138 total: 845ms remaining: 2.73s 236: learn: 0.4663092 total: 848ms remaining: 2.73s 237: learn: 0.4655198 total: 851ms remaining: 2.72s 238: learn: 0.4644130 total: 854ms remaining: 2.72s 239: learn: 0.4633542 total: 857ms remaining: 2.71s 240: learn: 0.4624652 total: 860ms remaining: 2.71s 241: learn: 0.4612790 total: 864ms remaining: 2.71s 242: learn: 0.4603322 total: 867ms remaining: 2.7s 243: learn: 0.4595621 total: 870ms remaining: 2.69s 244: learn: 0.4590279 total: 873ms remaining: 2.69s 245: learn: 0.4581471 total: 876ms remaining: 2.69s 246: learn: 0.4575228 total: 880ms remaining: 2.68s 247: learn: 0.4573403 total: 883ms remaining: 2.68s 248: learn: 0.4564604 total: 886ms remaining: 2.67s 249: learn: 0.4556477 total: 889ms remaining: 2.67s 250: learn: 0.4546236 total: 892ms remaining: 2.66s 251: learn: 0.4533804 total: 895ms remaining: 2.66s 252: learn: 0.4528083 total: 898ms remaining: 2.65s 253: learn: 0.4522885 total: 902ms remaining: 2.65s 254: learn: 0.4516671 total: 905ms remaining: 2.64s 255: learn: 0.4509067 total: 908ms remaining: 2.64s 256: learn: 0.4502604 total: 911ms remaining: 2.63s 257: learn: 0.4496600 total: 914ms remaining: 2.63s 258: learn: 0.4490141 total: 917ms remaining: 2.62s 259: learn: 0.4483730 total: 921ms remaining: 2.62s 260: learn: 0.4476013 total: 924ms remaining: 2.62s 261: learn: 0.4464556 total: 928ms remaining: 2.61s 262: learn: 0.4460153 total: 931ms remaining: 2.61s 263: learn: 0.4452759 total: 934ms remaining: 2.6s 264: learn: 0.4448251 total: 937ms remaining: 2.6s 265: learn: 0.4439872 total: 940ms remaining: 2.59s 266: learn: 0.4432581 total: 943ms remaining: 2.59s 267: learn: 0.4425783 total: 946ms remaining: 2.58s 268: learn: 0.4420917 total: 950ms remaining: 2.58s 269: learn: 0.4414548 total: 953ms remaining: 2.58s 270: learn: 0.4406199 total: 
956ms remaining: 2.57s 271: learn: 0.4401848 total: 959ms remaining: 2.57s 272: learn: 0.4396981 total: 963ms remaining: 2.56s 273: learn: 0.4391266 total: 966ms remaining: 2.56s 274: learn: 0.4389288 total: 969ms remaining: 2.56s 275: learn: 0.4383482 total: 972ms remaining: 2.55s 276: learn: 0.4374541 total: 976ms remaining: 2.55s 277: learn: 0.4366817 total: 979ms remaining: 2.54s 278: learn: 0.4361637 total: 984ms remaining: 2.54s 279: learn: 0.4355985 total: 987ms remaining: 2.54s 280: learn: 0.4349846 total: 992ms remaining: 2.54s 281: learn: 0.4343480 total: 995ms remaining: 2.53s 282: learn: 0.4340070 total: 999ms remaining: 2.53s 283: learn: 0.4330370 total: 1s remaining: 2.53s 284: learn: 0.4323418 total: 1s remaining: 2.52s 285: learn: 0.4320533 total: 1.01s remaining: 2.52s 286: learn: 0.4312878 total: 1.01s remaining: 2.51s 287: learn: 0.4302917 total: 1.01s remaining: 2.51s 288: learn: 0.4296243 total: 1.02s remaining: 2.5s 289: learn: 0.4289696 total: 1.02s remaining: 2.5s 290: learn: 0.4280839 total: 1.02s remaining: 2.5s 291: learn: 0.4276341 total: 1.03s remaining: 2.49s 292: learn: 0.4273110 total: 1.03s remaining: 2.49s 293: learn: 0.4264972 total: 1.03s remaining: 2.48s 294: learn: 0.4260751 total: 1.04s remaining: 2.48s 295: learn: 0.4250758 total: 1.04s remaining: 2.47s 296: learn: 0.4242247 total: 1.04s remaining: 2.47s 297: learn: 0.4237303 total: 1.05s remaining: 2.46s 298: learn: 0.4229408 total: 1.05s remaining: 2.46s 299: learn: 0.4221990 total: 1.05s remaining: 2.46s 300: learn: 0.4216592 total: 1.06s remaining: 2.45s 301: learn: 0.4215269 total: 1.06s remaining: 2.45s 302: learn: 0.4209301 total: 1.06s remaining: 2.44s 303: learn: 0.4204768 total: 1.06s remaining: 2.44s 304: learn: 0.4199938 total: 1.07s remaining: 2.43s 305: learn: 0.4190819 total: 1.07s remaining: 2.43s 306: learn: 0.4186059 total: 1.07s remaining: 2.43s 307: learn: 0.4179422 total: 1.08s remaining: 2.42s 308: learn: 0.4173496 total: 1.08s remaining: 2.42s 309: 
learn: 0.4165331 total: 1.08s remaining: 2.41s 310: learn: 0.4156881 total: 1.09s remaining: 2.41s 311: learn: 0.4150547 total: 1.09s remaining: 2.4s 312: learn: 0.4144039 total: 1.09s remaining: 2.4s 313: learn: 0.4137554 total: 1.1s remaining: 2.4s 314: learn: 0.4128423 total: 1.1s remaining: 2.39s 315: learn: 0.4121851 total: 1.1s remaining: 2.39s 316: learn: 0.4120652 total: 1.11s remaining: 2.38s 317: learn: 0.4119489 total: 1.11s remaining: 2.38s 318: learn: 0.4113800 total: 1.11s remaining: 2.37s 319: learn: 0.4110944 total: 1.11s remaining: 2.37s 320: learn: 0.4103885 total: 1.12s remaining: 2.37s 321: learn: 0.4098944 total: 1.12s remaining: 2.36s 322: learn: 0.4096331 total: 1.12s remaining: 2.36s 323: learn: 0.4092385 total: 1.13s remaining: 2.35s 324: learn: 0.4088671 total: 1.13s remaining: 2.35s 325: learn: 0.4085054 total: 1.13s remaining: 2.35s 326: learn: 0.4079698 total: 1.14s remaining: 2.34s 327: learn: 0.4075026 total: 1.14s remaining: 2.34s 328: learn: 0.4065781 total: 1.14s remaining: 2.33s 329: learn: 0.4061529 total: 1.15s remaining: 2.33s 330: learn: 0.4057269 total: 1.15s remaining: 2.32s 331: learn: 0.4050642 total: 1.15s remaining: 2.32s 332: learn: 0.4047541 total: 1.16s remaining: 2.32s 333: learn: 0.4044138 total: 1.16s remaining: 2.31s 334: learn: 0.4040232 total: 1.16s remaining: 2.31s 335: learn: 0.4031800 total: 1.17s remaining: 2.3s 336: learn: 0.4026754 total: 1.17s remaining: 2.3s 337: learn: 0.4022685 total: 1.17s remaining: 2.3s 338: learn: 0.4017881 total: 1.18s remaining: 2.29s 339: learn: 0.4015254 total: 1.18s remaining: 2.29s 340: learn: 0.4009658 total: 1.18s remaining: 2.29s 341: learn: 0.4008354 total: 1.19s remaining: 2.28s 342: learn: 0.4002425 total: 1.19s remaining: 2.28s 343: learn: 0.3997084 total: 1.19s remaining: 2.27s 344: learn: 0.3992679 total: 1.2s remaining: 2.27s 345: learn: 0.3989327 total: 1.2s remaining: 2.27s 346: learn: 0.3983560 total: 1.2s remaining: 2.26s 347: learn: 0.3978115 total: 1.21s 
remaining: 2.26s 348: learn: 0.3972325 total: 1.21s remaining: 2.25s 349: learn: 0.3970591 total: 1.21s remaining: 2.25s 350: learn: 0.3967434 total: 1.21s remaining: 2.25s 351: learn: 0.3964092 total: 1.22s remaining: 2.24s 352: learn: 0.3959126 total: 1.22s remaining: 2.24s 353: learn: 0.3956668 total: 1.22s remaining: 2.23s 354: learn: 0.3952955 total: 1.23s remaining: 2.23s 355: learn: 0.3948577 total: 1.23s remaining: 2.23s 356: learn: 0.3942103 total: 1.23s remaining: 2.22s 357: learn: 0.3934125 total: 1.24s remaining: 2.22s 358: learn: 0.3929393 total: 1.24s remaining: 2.21s 359: learn: 0.3923515 total: 1.24s remaining: 2.21s 360: learn: 0.3918921 total: 1.25s remaining: 2.21s 361: learn: 0.3913288 total: 1.25s remaining: 2.2s 362: learn: 0.3909709 total: 1.25s remaining: 2.2s 363: learn: 0.3907805 total: 1.25s remaining: 2.19s 364: learn: 0.3902974 total: 1.26s remaining: 2.19s 365: learn: 0.3897694 total: 1.26s remaining: 2.19s 366: learn: 0.3895400 total: 1.26s remaining: 2.18s 367: learn: 0.3887777 total: 1.27s remaining: 2.18s 368: learn: 0.3879556 total: 1.27s remaining: 2.17s 369: learn: 0.3872718 total: 1.27s remaining: 2.17s 370: learn: 0.3866534 total: 1.28s remaining: 2.17s 371: learn: 0.3859462 total: 1.28s remaining: 2.16s 372: learn: 0.3853502 total: 1.28s remaining: 2.16s 373: learn: 0.3851891 total: 1.29s remaining: 2.15s 374: learn: 0.3846603 total: 1.29s remaining: 2.15s 375: learn: 0.3844259 total: 1.29s remaining: 2.15s 376: learn: 0.3838207 total: 1.3s remaining: 2.14s 377: learn: 0.3833191 total: 1.3s remaining: 2.14s 378: learn: 0.3826555 total: 1.3s remaining: 2.13s 379: learn: 0.3825492 total: 1.31s remaining: 2.13s 380: learn: 0.3819978 total: 1.31s remaining: 2.13s 381: learn: 0.3819036 total: 1.31s remaining: 2.12s 382: learn: 0.3814676 total: 1.32s remaining: 2.12s 383: learn: 0.3806646 total: 1.32s remaining: 2.12s 384: learn: 0.3801625 total: 1.32s remaining: 2.11s 385: learn: 0.3795941 total: 1.32s remaining: 2.11s 386: learn: 
0.3792967 total: 1.33s remaining: 2.1s 387: learn: 0.3786377 total: 1.33s remaining: 2.1s 388: learn: 0.3781177 total: 1.34s remaining: 2.1s 389: learn: 0.3776432 total: 1.34s remaining: 2.1s 390: learn: 0.3771742 total: 1.34s remaining: 2.09s 391: learn: 0.3767596 total: 1.35s remaining: 2.09s 392: learn: 0.3762086 total: 1.35s remaining: 2.09s 393: learn: 0.3758127 total: 1.35s remaining: 2.08s 394: learn: 0.3752773 total: 1.36s remaining: 2.08s 395: learn: 0.3748929 total: 1.36s remaining: 2.08s 396: learn: 0.3744049 total: 1.36s remaining: 2.07s 397: learn: 0.3742496 total: 1.37s remaining: 2.07s 398: learn: 0.3737786 total: 1.37s remaining: 2.07s 399: learn: 0.3734626 total: 1.38s remaining: 2.07s 400: learn: 0.3730718 total: 1.38s remaining: 2.06s 401: learn: 0.3724141 total: 1.38s remaining: 2.06s 402: learn: 0.3719930 total: 1.39s remaining: 2.05s 403: learn: 0.3716086 total: 1.39s remaining: 2.05s 404: learn: 0.3709842 total: 1.39s remaining: 2.05s 405: learn: 0.3704663 total: 1.4s remaining: 2.04s 406: learn: 0.3699775 total: 1.4s remaining: 2.04s 407: learn: 0.3698619 total: 1.4s remaining: 2.04s 408: learn: 0.3694625 total: 1.41s remaining: 2.03s 409: learn: 0.3688800 total: 1.41s remaining: 2.03s 410: learn: 0.3687014 total: 1.41s remaining: 2.02s 411: learn: 0.3682739 total: 1.42s remaining: 2.02s 412: learn: 0.3678543 total: 1.42s remaining: 2.02s 413: learn: 0.3672088 total: 1.42s remaining: 2.01s 414: learn: 0.3667914 total: 1.43s remaining: 2.01s 415: learn: 0.3666853 total: 1.43s remaining: 2.01s 416: learn: 0.3664491 total: 1.43s remaining: 2s 417: learn: 0.3658085 total: 1.44s remaining: 2s 418: learn: 0.3655759 total: 1.44s remaining: 1.99s 419: learn: 0.3650955 total: 1.44s remaining: 1.99s 420: learn: 0.3646143 total: 1.45s remaining: 1.99s 421: learn: 0.3643238 total: 1.45s remaining: 1.98s 422: learn: 0.3640482 total: 1.45s remaining: 1.98s 423: learn: 0.3638433 total: 1.45s remaining: 1.98s 424: learn: 0.3634363 total: 1.46s remaining: 
1.97s 425: learn: 0.3627563 total: 1.46s remaining: 1.97s 426: learn: 0.3620870 total: 1.46s remaining: 1.97s 427: learn: 0.3615678 total: 1.47s remaining: 1.96s 428: learn: 0.3608775 total: 1.47s remaining: 1.96s 429: learn: 0.3605767 total: 1.47s remaining: 1.95s 430: learn: 0.3600947 total: 1.48s remaining: 1.95s 431: learn: 0.3599022 total: 1.48s remaining: 1.95s 432: learn: 0.3592169 total: 1.48s remaining: 1.94s 433: learn: 0.3586974 total: 1.49s remaining: 1.94s 434: learn: 0.3583809 total: 1.49s remaining: 1.93s 435: learn: 0.3580570 total: 1.49s remaining: 1.93s 436: learn: 0.3579281 total: 1.5s remaining: 1.93s 437: learn: 0.3572908 total: 1.5s remaining: 1.92s 438: learn: 0.3567620 total: 1.5s remaining: 1.92s 439: learn: 0.3563369 total: 1.5s remaining: 1.92s 440: learn: 0.3559366 total: 1.51s remaining: 1.91s 441: learn: 0.3555465 total: 1.51s remaining: 1.91s 442: learn: 0.3551928 total: 1.51s remaining: 1.9s 443: learn: 0.3546176 total: 1.52s remaining: 1.9s 444: learn: 0.3542245 total: 1.52s remaining: 1.9s 445: learn: 0.3537539 total: 1.52s remaining: 1.89s 446: learn: 0.3533939 total: 1.53s remaining: 1.89s 447: learn: 0.3530292 total: 1.53s remaining: 1.89s 448: learn: 0.3527625 total: 1.53s remaining: 1.88s 449: learn: 0.3523223 total: 1.54s remaining: 1.88s 450: learn: 0.3522093 total: 1.54s remaining: 1.88s 451: learn: 0.3515991 total: 1.54s remaining: 1.87s 452: learn: 0.3513537 total: 1.55s remaining: 1.87s 453: learn: 0.3508370 total: 1.55s remaining: 1.86s 454: learn: 0.3505712 total: 1.55s remaining: 1.86s 455: learn: 0.3500935 total: 1.56s remaining: 1.86s 456: learn: 0.3497203 total: 1.56s remaining: 1.85s 457: learn: 0.3493516 total: 1.56s remaining: 1.85s 458: learn: 0.3493006 total: 1.57s remaining: 1.85s 459: learn: 0.3492219 total: 1.57s remaining: 1.84s 460: learn: 0.3486440 total: 1.57s remaining: 1.84s 461: learn: 0.3481056 total: 1.58s remaining: 1.84s 462: learn: 0.3476578 total: 1.58s remaining: 1.83s 463: learn: 0.3472720 
total: 1.58s remaining: 1.83s 464: learn: 0.3469809 total: 1.59s remaining: 1.83s 465: learn: 0.3464662 total: 1.59s remaining: 1.82s 466: learn: 0.3461573 total: 1.59s remaining: 1.82s 467: learn: 0.3455262 total: 1.6s remaining: 1.81s 468: learn: 0.3450315 total: 1.6s remaining: 1.81s 469: learn: 0.3447621 total: 1.6s remaining: 1.81s 470: learn: 0.3443707 total: 1.61s remaining: 1.8s 471: learn: 0.3440078 total: 1.61s remaining: 1.8s 472: learn: 0.3436588 total: 1.61s remaining: 1.8s 473: learn: 0.3434317 total: 1.61s remaining: 1.79s 474: learn: 0.3429587 total: 1.62s remaining: 1.79s 475: learn: 0.3427129 total: 1.62s remaining: 1.78s 476: learn: 0.3426614 total: 1.63s remaining: 1.78s 477: learn: 0.3424276 total: 1.63s remaining: 1.78s 478: learn: 0.3419850 total: 1.63s remaining: 1.77s 479: learn: 0.3414828 total: 1.63s remaining: 1.77s 480: learn: 0.3414148 total: 1.64s remaining: 1.77s 481: learn: 0.3411039 total: 1.64s remaining: 1.76s 482: learn: 0.3404094 total: 1.64s remaining: 1.76s 483: learn: 0.3400870 total: 1.65s remaining: 1.76s 484: learn: 0.3397870 total: 1.65s remaining: 1.75s 485: learn: 0.3395923 total: 1.65s remaining: 1.75s 486: learn: 0.3393283 total: 1.66s remaining: 1.75s 487: learn: 0.3390542 total: 1.66s remaining: 1.74s 488: learn: 0.3387735 total: 1.66s remaining: 1.74s 489: learn: 0.3384992 total: 1.67s remaining: 1.74s 490: learn: 0.3382679 total: 1.67s remaining: 1.73s 491: learn: 0.3379591 total: 1.67s remaining: 1.73s 492: learn: 0.3375240 total: 1.68s remaining: 1.72s 493: learn: 0.3371663 total: 1.68s remaining: 1.72s 494: learn: 0.3368073 total: 1.68s remaining: 1.72s 495: learn: 0.3364258 total: 1.69s remaining: 1.71s 496: learn: 0.3357624 total: 1.69s remaining: 1.71s 497: learn: 0.3353391 total: 1.69s remaining: 1.71s 498: learn: 0.3351359 total: 1.7s remaining: 1.7s 499: learn: 0.3347721 total: 1.7s remaining: 1.7s 500: learn: 0.3344689 total: 1.7s remaining: 1.7s 501: learn: 0.3338885 total: 1.71s remaining: 1.69s 502: 
learn: 0.3335304 total: 1.71s remaining: 1.69s 503: learn: 0.3333156 total: 1.71s remaining: 1.68s 504: learn: 0.3326858 total: 1.71s remaining: 1.68s 505: learn: 0.3322794 total: 1.72s remaining: 1.68s 506: learn: 0.3317957 total: 1.72s remaining: 1.67s 507: learn: 0.3315374 total: 1.72s remaining: 1.67s 508: learn: 0.3310346 total: 1.73s remaining: 1.67s 509: learn: 0.3308218 total: 1.73s remaining: 1.66s 510: learn: 0.3305515 total: 1.73s remaining: 1.66s 511: learn: 0.3303324 total: 1.74s remaining: 1.66s 512: learn: 0.3298027 total: 1.74s remaining: 1.65s 513: learn: 0.3295265 total: 1.74s remaining: 1.65s 514: learn: 0.3292161 total: 1.75s remaining: 1.64s 515: learn: 0.3289082 total: 1.75s remaining: 1.64s 516: learn: 0.3284967 total: 1.75s remaining: 1.64s 517: learn: 0.3280931 total: 1.75s remaining: 1.63s 518: learn: 0.3277610 total: 1.76s remaining: 1.63s 519: learn: 0.3273487 total: 1.76s remaining: 1.63s 520: learn: 0.3270085 total: 1.76s remaining: 1.62s 521: learn: 0.3267775 total: 1.77s remaining: 1.62s 522: learn: 0.3265932 total: 1.77s remaining: 1.62s 523: learn: 0.3262778 total: 1.78s remaining: 1.61s 524: learn: 0.3256224 total: 1.78s remaining: 1.61s 525: learn: 0.3253104 total: 1.78s remaining: 1.61s 526: learn: 0.3252620 total: 1.78s remaining: 1.6s 527: learn: 0.3248235 total: 1.79s remaining: 1.6s 528: learn: 0.3246449 total: 1.79s remaining: 1.59s 529: learn: 0.3241613 total: 1.79s remaining: 1.59s 530: learn: 0.3239438 total: 1.8s remaining: 1.59s 531: learn: 0.3235460 total: 1.8s remaining: 1.58s 532: learn: 0.3228695 total: 1.8s remaining: 1.58s 533: learn: 0.3226564 total: 1.81s remaining: 1.58s 534: learn: 0.3223583 total: 1.81s remaining: 1.57s 535: learn: 0.3221698 total: 1.81s remaining: 1.57s 536: learn: 0.3216386 total: 1.82s remaining: 1.57s 537: learn: 0.3208786 total: 1.82s remaining: 1.56s 538: learn: 0.3206405 total: 1.82s remaining: 1.56s 539: learn: 0.3200470 total: 1.83s remaining: 1.56s 540: learn: 0.3194599 total: 
1.83s remaining: 1.55s 541: learn: 0.3191917 total: 1.83s remaining: 1.55s 542: learn: 0.3187413 total: 1.84s remaining: 1.54s 543: learn: 0.3183845 total: 1.84s remaining: 1.54s 544: learn: 0.3181347 total: 1.84s remaining: 1.54s 545: learn: 0.3179402 total: 1.84s remaining: 1.53s 546: learn: 0.3176041 total: 1.85s remaining: 1.53s 547: learn: 0.3171493 total: 1.85s remaining: 1.53s 548: learn: 0.3167678 total: 1.85s remaining: 1.52s 549: learn: 0.3161866 total: 1.86s remaining: 1.52s 550: learn: 0.3158943 total: 1.86s remaining: 1.52s 551: learn: 0.3155500 total: 1.86s remaining: 1.51s 552: learn: 0.3151446 total: 1.87s remaining: 1.51s 553: learn: 0.3150676 total: 1.87s remaining: 1.51s 554: learn: 0.3148019 total: 1.87s remaining: 1.5s 555: learn: 0.3146087 total: 1.88s remaining: 1.5s 556: learn: 0.3145459 total: 1.88s remaining: 1.5s 557: learn: 0.3144546 total: 1.88s remaining: 1.49s 558: learn: 0.3140608 total: 1.89s remaining: 1.49s 559: learn: 0.3137361 total: 1.89s remaining: 1.49s 560: learn: 0.3134164 total: 1.89s remaining: 1.48s 561: learn: 0.3133328 total: 1.9s remaining: 1.48s 562: learn: 0.3129987 total: 1.9s remaining: 1.47s 563: learn: 0.3127488 total: 1.9s remaining: 1.47s 564: learn: 0.3124070 total: 1.91s remaining: 1.47s 565: learn: 0.3120479 total: 1.91s remaining: 1.46s 566: learn: 0.3118329 total: 1.91s remaining: 1.46s 567: learn: 0.3115109 total: 1.92s remaining: 1.46s 568: learn: 0.3112201 total: 1.92s remaining: 1.45s 569: learn: 0.3108064 total: 1.92s remaining: 1.45s 570: learn: 0.3105657 total: 1.92s remaining: 1.45s 571: learn: 0.3102145 total: 1.93s remaining: 1.44s 572: learn: 0.3097385 total: 1.93s remaining: 1.44s 573: learn: 0.3093835 total: 1.93s remaining: 1.44s 574: learn: 0.3090351 total: 1.94s remaining: 1.43s 575: learn: 0.3087266 total: 1.94s remaining: 1.43s 576: learn: 0.3085897 total: 1.94s remaining: 1.43s 577: learn: 0.3083184 total: 1.95s remaining: 1.42s 578: learn: 0.3082340 total: 1.95s remaining: 1.42s 579: 
learn: 0.3076160 total: 1.95s remaining: 1.41s 580: learn: 0.3073862 total: 1.96s remaining: 1.41s 581: learn: 0.3070794 total: 1.96s remaining: 1.41s 582: learn: 0.3069234 total: 1.96s remaining: 1.4s 583: learn: 0.3068346 total: 1.97s remaining: 1.4s 584: learn: 0.3065929 total: 1.97s remaining: 1.4s 585: learn: 0.3063195 total: 1.98s remaining: 1.4s 586: learn: 0.3059346 total: 1.98s remaining: 1.39s 587: learn: 0.3056757 total: 1.98s remaining: 1.39s 588: learn: 0.3053577 total: 1.99s remaining: 1.39s 589: learn: 0.3049893 total: 1.99s remaining: 1.38s 590: learn: 0.3047812 total: 1.99s remaining: 1.38s 591: learn: 0.3046023 total: 2s remaining: 1.38s 592: learn: 0.3042608 total: 2s remaining: 1.37s 593: learn: 0.3039681 total: 2s remaining: 1.37s 594: learn: 0.3035661 total: 2.01s remaining: 1.37s 595: learn: 0.3033736 total: 2.01s remaining: 1.36s 596: learn: 0.3029874 total: 2.01s remaining: 1.36s 597: learn: 0.3023723 total: 2.02s remaining: 1.36s 598: learn: 0.3020138 total: 2.02s remaining: 1.35s 599: learn: 0.3017719 total: 2.02s remaining: 1.35s 600: learn: 0.3015049 total: 2.03s remaining: 1.35s 601: learn: 0.3011821 total: 2.03s remaining: 1.34s 602: learn: 0.3009602 total: 2.03s remaining: 1.34s 603: learn: 0.3005602 total: 2.04s remaining: 1.34s 604: learn: 0.3004219 total: 2.04s remaining: 1.33s 605: learn: 0.3002263 total: 2.04s remaining: 1.33s 606: learn: 0.2999867 total: 2.05s remaining: 1.32s 607: learn: 0.2994174 total: 2.05s remaining: 1.32s 608: learn: 0.2991720 total: 2.05s remaining: 1.32s 609: learn: 0.2987068 total: 2.06s remaining: 1.31s 610: learn: 0.2985418 total: 2.06s remaining: 1.31s 611: learn: 0.2981738 total: 2.06s remaining: 1.31s 612: learn: 0.2979976 total: 2.07s remaining: 1.3s 613: learn: 0.2976747 total: 2.07s remaining: 1.3s 614: learn: 0.2973307 total: 2.07s remaining: 1.3s 615: learn: 0.2970624 total: 2.08s remaining: 1.29s 616: learn: 0.2967306 total: 2.08s remaining: 1.29s 617: learn: 0.2964076 total: 2.08s 
remaining: 1.29s 618: learn: 0.2962058 total: 2.09s remaining: 1.28s 619: learn: 0.2957920 total: 2.09s remaining: 1.28s 620: learn: 0.2954920 total: 2.09s remaining: 1.28s 621: learn: 0.2950938 total: 2.1s remaining: 1.27s 622: learn: 0.2948042 total: 2.1s remaining: 1.27s 623: learn: 0.2943325 total: 2.1s remaining: 1.27s 624: learn: 0.2940090 total: 2.11s remaining: 1.26s 625: learn: 0.2934846 total: 2.11s remaining: 1.26s 626: learn: 0.2930588 total: 2.11s remaining: 1.26s 627: learn: 0.2928589 total: 2.12s remaining: 1.25s 628: learn: 0.2924964 total: 2.12s remaining: 1.25s 629: learn: 0.2924495 total: 2.12s remaining: 1.25s 630: learn: 0.2920443 total: 2.13s remaining: 1.24s 631: learn: 0.2916802 total: 2.13s remaining: 1.24s 632: learn: 0.2914135 total: 2.13s remaining: 1.24s 633: learn: 0.2912321 total: 2.13s remaining: 1.23s 634: learn: 0.2909848 total: 2.14s remaining: 1.23s 635: learn: 0.2904671 total: 2.14s remaining: 1.23s 636: learn: 0.2902028 total: 2.15s remaining: 1.22s 637: learn: 0.2898836 total: 2.15s remaining: 1.22s 638: learn: 0.2895101 total: 2.15s remaining: 1.22s 639: learn: 0.2892208 total: 2.15s remaining: 1.21s 640: learn: 0.2891015 total: 2.16s remaining: 1.21s 641: learn: 0.2888759 total: 2.16s remaining: 1.21s 642: learn: 0.2887283 total: 2.17s remaining: 1.2s 643: learn: 0.2885609 total: 2.17s remaining: 1.2s 644: learn: 0.2881727 total: 2.17s remaining: 1.2s 645: learn: 0.2877596 total: 2.17s remaining: 1.19s 646: learn: 0.2874737 total: 2.18s remaining: 1.19s 647: learn: 0.2870647 total: 2.18s remaining: 1.19s 648: learn: 0.2865679 total: 2.18s remaining: 1.18s 649: learn: 0.2863326 total: 2.19s remaining: 1.18s 650: learn: 0.2861018 total: 2.19s remaining: 1.17s 651: learn: 0.2857576 total: 2.19s remaining: 1.17s 652: learn: 0.2853869 total: 2.2s remaining: 1.17s 653: learn: 0.2851109 total: 2.2s remaining: 1.16s 654: learn: 0.2848755 total: 2.2s remaining: 1.16s 655: learn: 0.2847261 total: 2.21s remaining: 1.16s 656: learn: 
0.2844077 total: 2.21s remaining: 1.15s 657: learn: 0.2840529 total: 2.21s remaining: 1.15s 658: learn: 0.2838391 total: 2.22s remaining: 1.15s 659: learn: 0.2833714 total: 2.22s remaining: 1.14s 660: learn: 0.2831854 total: 2.22s remaining: 1.14s 661: learn: 0.2831066 total: 2.23s remaining: 1.14s 662: learn: 0.2827877 total: 2.23s remaining: 1.13s 663: learn: 0.2825169 total: 2.23s remaining: 1.13s 664: learn: 0.2819040 total: 2.23s remaining: 1.13s 665: learn: 0.2815437 total: 2.24s remaining: 1.12s 666: learn: 0.2812914 total: 2.24s remaining: 1.12s 667: learn: 0.2809317 total: 2.25s remaining: 1.11s 668: learn: 0.2805022 total: 2.25s remaining: 1.11s 669: learn: 0.2802108 total: 2.25s remaining: 1.11s 670: learn: 0.2800260 total: 2.25s remaining: 1.1s 671: learn: 0.2796513 total: 2.26s remaining: 1.1s 672: learn: 0.2793144 total: 2.26s remaining: 1.1s 673: learn: 0.2787577 total: 2.26s remaining: 1.09s 674: learn: 0.2786898 total: 2.27s remaining: 1.09s 675: learn: 0.2783373 total: 2.27s remaining: 1.09s 676: learn: 0.2780719 total: 2.27s remaining: 1.08s 677: learn: 0.2778878 total: 2.28s remaining: 1.08s 678: learn: 0.2775800 total: 2.28s remaining: 1.08s 679: learn: 0.2774185 total: 2.28s remaining: 1.07s 680: learn: 0.2773129 total: 2.29s remaining: 1.07s 681: learn: 0.2770925 total: 2.29s remaining: 1.07s 682: learn: 0.2768828 total: 2.29s remaining: 1.06s 683: learn: 0.2768476 total: 2.29s remaining: 1.06s 684: learn: 0.2767953 total: 2.3s remaining: 1.06s 685: learn: 0.2764390 total: 2.3s remaining: 1.05s 686: learn: 0.2762438 total: 2.31s remaining: 1.05s 687: learn: 0.2758780 total: 2.31s remaining: 1.05s 688: learn: 0.2755027 total: 2.31s remaining: 1.04s 689: learn: 0.2751770 total: 2.31s remaining: 1.04s 690: learn: 0.2747903 total: 2.32s remaining: 1.04s 691: learn: 0.2744635 total: 2.32s remaining: 1.03s 692: learn: 0.2740074 total: 2.32s remaining: 1.03s 693: learn: 0.2739810 total: 2.33s remaining: 1.03s 694: learn: 0.2736186 total: 2.33s 
remaining: 1.02s 695: learn: 0.2734820 total: 2.33s remaining: 1.02s 696: learn: 0.2732142 total: 2.34s remaining: 1.02s 697: learn: 0.2729981 total: 2.34s remaining: 1.01s 698: learn: 0.2728261 total: 2.34s remaining: 1.01s 699: learn: 0.2725844 total: 2.35s remaining: 1.01s 700: learn: 0.2721828 total: 2.35s remaining: 1s 701: learn: 0.2719829 total: 2.35s remaining: 999ms 702: learn: 0.2719228 total: 2.36s remaining: 996ms 703: learn: 0.2716038 total: 2.36s remaining: 993ms 704: learn: 0.2714065 total: 2.36s remaining: 990ms 705: learn: 0.2710144 total: 2.37s remaining: 986ms 706: learn: 0.2709227 total: 2.37s remaining: 983ms 707: learn: 0.2707329 total: 2.37s remaining: 979ms 708: learn: 0.2704630 total: 2.38s remaining: 976ms 709: learn: 0.2702448 total: 2.38s remaining: 972ms 710: learn: 0.2700055 total: 2.38s remaining: 969ms 711: learn: 0.2694917 total: 2.39s remaining: 966ms 712: learn: 0.2693000 total: 2.39s remaining: 962ms 713: learn: 0.2690283 total: 2.39s remaining: 959ms 714: learn: 0.2687780 total: 2.4s remaining: 955ms 715: learn: 0.2687384 total: 2.4s remaining: 952ms 716: learn: 0.2685078 total: 2.4s remaining: 948ms 717: learn: 0.2681714 total: 2.41s remaining: 945ms 718: learn: 0.2679557 total: 2.41s remaining: 942ms 719: learn: 0.2678159 total: 2.41s remaining: 938ms 720: learn: 0.2676483 total: 2.42s remaining: 935ms 721: learn: 0.2674793 total: 2.42s remaining: 931ms 722: learn: 0.2669124 total: 2.42s remaining: 928ms 723: learn: 0.2666120 total: 2.42s remaining: 925ms 724: learn: 0.2665815 total: 2.43s remaining: 921ms 725: learn: 0.2662691 total: 2.43s remaining: 918ms 726: learn: 0.2661253 total: 2.43s remaining: 914ms 727: learn: 0.2659665 total: 2.44s remaining: 911ms 728: learn: 0.2656264 total: 2.44s remaining: 907ms 729: learn: 0.2654130 total: 2.44s remaining: 904ms 730: learn: 0.2652073 total: 2.45s remaining: 900ms 731: learn: 0.2649612 total: 2.45s remaining: 897ms 732: learn: 0.2648599 total: 2.45s remaining: 894ms 733: learn: 
0.2646590 total: 2.46s remaining: 890ms 734: learn: 0.2644626 total: 2.46s remaining: 887ms 735: learn: 0.2641833 total: 2.46s remaining: 883ms 736: learn: 0.2639876 total: 2.46s remaining: 880ms 737: learn: 0.2637405 total: 2.47s remaining: 877ms 738: learn: 0.2636841 total: 2.47s remaining: 873ms 739: learn: 0.2632669 total: 2.48s remaining: 870ms 740: learn: 0.2629340 total: 2.48s remaining: 866ms 741: learn: 0.2624565 total: 2.48s remaining: 863ms 742: learn: 0.2621763 total: 2.48s remaining: 860ms 743: learn: 0.2620465 total: 2.49s remaining: 856ms 744: learn: 0.2617521 total: 2.49s remaining: 853ms 745: learn: 0.2616214 total: 2.49s remaining: 849ms 746: learn: 0.2613156 total: 2.5s remaining: 846ms 747: learn: 0.2610295 total: 2.5s remaining: 843ms 748: learn: 0.2605249 total: 2.5s remaining: 839ms 749: learn: 0.2602300 total: 2.51s remaining: 836ms 750: learn: 0.2598565 total: 2.51s remaining: 832ms 751: learn: 0.2595145 total: 2.51s remaining: 829ms 752: learn: 0.2590396 total: 2.52s remaining: 826ms 753: learn: 0.2586583 total: 2.52s remaining: 822ms 754: learn: 0.2583753 total: 2.52s remaining: 819ms 755: learn: 0.2581020 total: 2.53s remaining: 815ms 756: learn: 0.2576964 total: 2.53s remaining: 813ms 757: learn: 0.2576649 total: 2.53s remaining: 809ms 758: learn: 0.2572911 total: 2.54s remaining: 806ms 759: learn: 0.2572575 total: 2.54s remaining: 803ms 760: learn: 0.2569378 total: 2.54s remaining: 799ms 761: learn: 0.2567645 total: 2.55s remaining: 796ms 762: learn: 0.2565097 total: 2.55s remaining: 793ms 763: learn: 0.2560526 total: 2.56s remaining: 790ms 764: learn: 0.2560135 total: 2.56s remaining: 786ms 765: learn: 0.2555596 total: 2.56s remaining: 783ms 766: learn: 0.2553998 total: 2.56s remaining: 779ms 767: learn: 0.2551371 total: 2.57s remaining: 776ms 768: learn: 0.2550625 total: 2.57s remaining: 773ms 769: learn: 0.2547166 total: 2.58s remaining: 769ms 770: learn: 0.2545018 total: 2.58s remaining: 766ms 771: learn: 0.2541630 total: 2.58s 
remaining: 763ms 772: learn: 0.2537533 total: 2.59s remaining: 760ms 773: learn: 0.2536258 total: 2.59s remaining: 756ms 774: learn: 0.2536134 total: 2.59s remaining: 753ms 775: learn: 0.2535560 total: 2.6s remaining: 749ms 776: learn: 0.2533754 total: 2.6s remaining: 746ms 777: learn: 0.2529591 total: 2.6s remaining: 743ms 778: learn: 0.2527568 total: 2.6s remaining: 739ms 779: learn: 0.2523681 total: 2.61s remaining: 736ms 780: learn: 0.2519061 total: 2.61s remaining: 733ms 781: learn: 0.2516554 total: 2.62s remaining: 729ms 782: learn: 0.2516155 total: 2.62s remaining: 726ms 783: learn: 0.2514085 total: 2.62s remaining: 722ms 784: learn: 0.2512819 total: 2.63s remaining: 719ms 785: learn: 0.2511030 total: 2.63s remaining: 716ms 786: learn: 0.2508719 total: 2.63s remaining: 712ms 787: learn: 0.2507493 total: 2.63s remaining: 709ms 788: learn: 0.2506815 total: 2.64s remaining: 705ms 789: learn: 0.2506345 total: 2.64s remaining: 702ms 790: learn: 0.2503126 total: 2.64s remaining: 699ms 791: learn: 0.2500426 total: 2.65s remaining: 695ms 792: learn: 0.2498370 total: 2.65s remaining: 692ms 793: learn: 0.2496517 total: 2.65s remaining: 688ms 794: learn: 0.2492671 total: 2.66s remaining: 685ms 795: learn: 0.2490305 total: 2.66s remaining: 682ms 796: learn: 0.2487405 total: 2.66s remaining: 678ms 797: learn: 0.2485630 total: 2.67s remaining: 675ms 798: learn: 0.2485411 total: 2.67s remaining: 671ms 799: learn: 0.2482743 total: 2.67s remaining: 668ms 800: learn: 0.2479895 total: 2.67s remaining: 665ms 801: learn: 0.2476468 total: 2.68s remaining: 661ms 802: learn: 0.2476272 total: 2.68s remaining: 658ms 803: learn: 0.2474233 total: 2.68s remaining: 654ms 804: learn: 0.2471691 total: 2.69s remaining: 651ms 805: learn: 0.2470093 total: 2.69s remaining: 648ms 806: learn: 0.2466640 total: 2.69s remaining: 644ms 807: learn: 0.2464114 total: 2.7s remaining: 641ms 808: learn: 0.2462497 total: 2.7s remaining: 637ms 809: learn: 0.2458715 total: 2.7s remaining: 634ms 810: learn: 
0.2455936 total: 2.71s remaining: 631ms 811: learn: 0.2454435 total: 2.71s remaining: 627ms 812: learn: 0.2452169 total: 2.71s remaining: 624ms 813: learn: 0.2451734 total: 2.72s remaining: 621ms 814: learn: 0.2448680 total: 2.72s remaining: 617ms 815: learn: 0.2448261 total: 2.72s remaining: 614ms 816: learn: 0.2443975 total: 2.73s remaining: 611ms 817: learn: 0.2443746 total: 2.73s remaining: 607ms 818: learn: 0.2440732 total: 2.73s remaining: 604ms 819: learn: 0.2439608 total: 2.73s remaining: 600ms 820: learn: 0.2438886 total: 2.74s remaining: 597ms 821: learn: 0.2437090 total: 2.74s remaining: 594ms 822: learn: 0.2436873 total: 2.75s remaining: 590ms 823: learn: 0.2434946 total: 2.75s remaining: 587ms 824: learn: 0.2431038 total: 2.75s remaining: 584ms 825: learn: 0.2428742 total: 2.76s remaining: 581ms 826: learn: 0.2425255 total: 2.76s remaining: 577ms 827: learn: 0.2423590 total: 2.76s remaining: 574ms 828: learn: 0.2418667 total: 2.77s remaining: 570ms 829: learn: 0.2415327 total: 2.77s remaining: 567ms 830: learn: 0.2412411 total: 2.77s remaining: 564ms 831: learn: 0.2408410 total: 2.77s remaining: 560ms 832: learn: 0.2407998 total: 2.78s remaining: 557ms 833: learn: 0.2406402 total: 2.78s remaining: 554ms 834: learn: 0.2406208 total: 2.78s remaining: 550ms 835: learn: 0.2405666 total: 2.79s remaining: 547ms 836: learn: 0.2403029 total: 2.79s remaining: 543ms 837: learn: 0.2401250 total: 2.79s remaining: 540ms 838: learn: 0.2401027 total: 2.8s remaining: 537ms 839: learn: 0.2398424 total: 2.8s remaining: 533ms 840: learn: 0.2398274 total: 2.8s remaining: 530ms 841: learn: 0.2395215 total: 2.81s remaining: 527ms 842: learn: 0.2392416 total: 2.81s remaining: 523ms 843: learn: 0.2390250 total: 2.81s remaining: 520ms 844: learn: 0.2388586 total: 2.82s remaining: 517ms 845: learn: 0.2386279 total: 2.82s remaining: 513ms 846: learn: 0.2384384 total: 2.82s remaining: 510ms 847: learn: 0.2383070 total: 2.83s remaining: 506ms 848: learn: 0.2381507 total: 2.83s 
remaining: 503ms 849: learn: 0.2379354 total: 2.83s remaining: 500ms 850: learn: 0.2376864 total: 2.83s remaining: 496ms 851: learn: 0.2373329 total: 2.84s remaining: 493ms 852: learn: 0.2372335 total: 2.84s remaining: 490ms 853: learn: 0.2367554 total: 2.84s remaining: 486ms 854: learn: 0.2365606 total: 2.85s remaining: 483ms 855: learn: 0.2365435 total: 2.85s remaining: 480ms 856: learn: 0.2362393 total: 2.85s remaining: 476ms 857: learn: 0.2361009 total: 2.86s remaining: 473ms 858: learn: 0.2360136 total: 2.86s remaining: 470ms 859: learn: 0.2357448 total: 2.86s remaining: 466ms 860: learn: 0.2354335 total: 2.87s remaining: 463ms 861: learn: 0.2353438 total: 2.87s remaining: 459ms 862: learn: 0.2353260 total: 2.87s remaining: 456ms 863: learn: 0.2348839 total: 2.88s remaining: 453ms 864: learn: 0.2347374 total: 2.88s remaining: 449ms 865: learn: 0.2345540 total: 2.88s remaining: 446ms 866: learn: 0.2343686 total: 2.89s remaining: 443ms 867: learn: 0.2341852 total: 2.89s remaining: 439ms 868: learn: 0.2339969 total: 2.89s remaining: 436ms 869: learn: 0.2339190 total: 2.9s remaining: 433ms 870: learn: 0.2337132 total: 2.9s remaining: 429ms 871: learn: 0.2333663 total: 2.9s remaining: 426ms 872: learn: 0.2331696 total: 2.9s remaining: 423ms 873: learn: 0.2329565 total: 2.91s remaining: 419ms 874: learn: 0.2326612 total: 2.91s remaining: 416ms 875: learn: 0.2323790 total: 2.91s remaining: 413ms 876: learn: 0.2322076 total: 2.92s remaining: 409ms 877: learn: 0.2319440 total: 2.92s remaining: 406ms 878: learn: 0.2317660 total: 2.92s remaining: 403ms 879: learn: 0.2317280 total: 2.93s remaining: 399ms 880: learn: 0.2317000 total: 2.93s remaining: 396ms 881: learn: 0.2315801 total: 2.93s remaining: 393ms 882: learn: 0.2315676 total: 2.94s remaining: 389ms 883: learn: 0.2314147 total: 2.94s remaining: 386ms 884: learn: 0.2312086 total: 2.94s remaining: 383ms 885: learn: 0.2308071 total: 2.95s remaining: 379ms 886: learn: 0.2304809 total: 2.95s remaining: 376ms 887: 
learn: 0.2302192 total: 2.95s remaining: 373ms 888: learn: 0.2299117 total: 2.96s remaining: 369ms 889: learn: 0.2297213 total: 2.96s remaining: 366ms 890: learn: 0.2296087 total: 2.96s remaining: 363ms 891: learn: 0.2294252 total: 2.97s remaining: 359ms 892: learn: 0.2292066 total: 2.97s remaining: 356ms 893: learn: 0.2288881 total: 2.97s remaining: 353ms 894: learn: 0.2286894 total: 2.98s remaining: 349ms 895: learn: 0.2284812 total: 2.98s remaining: 346ms 896: learn: 0.2281809 total: 2.98s remaining: 343ms 897: learn: 0.2279510 total: 2.99s remaining: 339ms 898: learn: 0.2279389 total: 2.99s remaining: 336ms 899: learn: 0.2277300 total: 2.99s remaining: 333ms 900: learn: 0.2274465 total: 3s remaining: 329ms 901: learn: 0.2274025 total: 3s remaining: 326ms 902: learn: 0.2272341 total: 3s remaining: 323ms 903: learn: 0.2272163 total: 3s remaining: 319ms 904: learn: 0.2270345 total: 3.01s remaining: 316ms 905: learn: 0.2270112 total: 3.01s remaining: 313ms 906: learn: 0.2267261 total: 3.02s remaining: 309ms 907: learn: 0.2265868 total: 3.02s remaining: 306ms 908: learn: 0.2262635 total: 3.02s remaining: 303ms 909: learn: 0.2261474 total: 3.03s remaining: 300ms 910: learn: 0.2261271 total: 3.03s remaining: 296ms 911: learn: 0.2259786 total: 3.03s remaining: 293ms 912: learn: 0.2257185 total: 3.04s remaining: 290ms 913: learn: 0.2255092 total: 3.04s remaining: 286ms 914: learn: 0.2253473 total: 3.04s remaining: 283ms 915: learn: 0.2252805 total: 3.05s remaining: 280ms 916: learn: 0.2252673 total: 3.05s remaining: 276ms 917: learn: 0.2250309 total: 3.05s remaining: 273ms 918: learn: 0.2250015 total: 3.06s remaining: 270ms 919: learn: 0.2249589 total: 3.06s remaining: 266ms 920: learn: 0.2248454 total: 3.07s remaining: 263ms 921: learn: 0.2246542 total: 3.07s remaining: 260ms 922: learn: 0.2243981 total: 3.07s remaining: 256ms 923: learn: 0.2242925 total: 3.08s remaining: 253ms 924: learn: 0.2240367 total: 3.08s remaining: 250ms 925: learn: 0.2236834 total: 3.08s 
remaining: 246ms 926: learn: 0.2234257 total: 3.09s remaining: 243ms 927: learn: 0.2230992 total: 3.09s remaining: 240ms 928: learn: 0.2230899 total: 3.1s remaining: 237ms 929: learn: 0.2228504 total: 3.1s remaining: 233ms 930: learn: 0.2227112 total: 3.1s remaining: 230ms 931: learn: 0.2224408 total: 3.11s remaining: 227ms 932: learn: 0.2222795 total: 3.11s remaining: 223ms 933: learn: 0.2220837 total: 3.11s remaining: 220ms 934: learn: 0.2218455 total: 3.12s remaining: 217ms 935: learn: 0.2216053 total: 3.12s remaining: 213ms 936: learn: 0.2212786 total: 3.12s remaining: 210ms 937: learn: 0.2210942 total: 3.13s remaining: 207ms 938: learn: 0.2207330 total: 3.13s remaining: 203ms 939: learn: 0.2204996 total: 3.13s remaining: 200ms 940: learn: 0.2201419 total: 3.14s remaining: 197ms 941: learn: 0.2197450 total: 3.14s remaining: 194ms 942: learn: 0.2195649 total: 3.15s remaining: 190ms 943: learn: 0.2194141 total: 3.15s remaining: 187ms 944: learn: 0.2192079 total: 3.15s remaining: 184ms 945: learn: 0.2189989 total: 3.16s remaining: 180ms 946: learn: 0.2187927 total: 3.16s remaining: 177ms 947: learn: 0.2186379 total: 3.16s remaining: 174ms 948: learn: 0.2184722 total: 3.17s remaining: 170ms 949: learn: 0.2183393 total: 3.17s remaining: 167ms 950: learn: 0.2181005 total: 3.17s remaining: 164ms 951: learn: 0.2179581 total: 3.18s remaining: 160ms 952: learn: 0.2177392 total: 3.18s remaining: 157ms 953: learn: 0.2174948 total: 3.18s remaining: 154ms 954: learn: 0.2173059 total: 3.19s remaining: 150ms 955: learn: 0.2171342 total: 3.19s remaining: 147ms 956: learn: 0.2169922 total: 3.19s remaining: 143ms 957: learn: 0.2167591 total: 3.2s remaining: 140ms 958: learn: 0.2164620 total: 3.2s remaining: 137ms 959: learn: 0.2161386 total: 3.2s remaining: 133ms 960: learn: 0.2161191 total: 3.21s remaining: 130ms 961: learn: 0.2161055 total: 3.21s remaining: 127ms 962: learn: 0.2158105 total: 3.21s remaining: 123ms 963: learn: 0.2154934 total: 3.22s remaining: 120ms 964: learn: 
0.2153511 total: 3.22s remaining: 117ms 965: learn: 0.2150633 total: 3.23s remaining: 114ms 966: learn: 0.2147549 total: 3.23s remaining: 110ms 967: learn: 0.2144673 total: 3.23s remaining: 107ms 968: learn: 0.2141893 total: 3.24s remaining: 104ms 969: learn: 0.2140516 total: 3.24s remaining: 100ms 970: learn: 0.2140342 total: 3.24s remaining: 96.9ms 971: learn: 0.2139076 total: 3.25s remaining: 93.6ms 972: learn: 0.2135383 total: 3.25s remaining: 90.2ms 973: learn: 0.2132737 total: 3.26s remaining: 86.9ms 974: learn: 0.2130593 total: 3.26s remaining: 83.6ms 975: learn: 0.2127680 total: 3.26s remaining: 80.3ms 976: learn: 0.2125970 total: 3.27s remaining: 76.9ms 977: learn: 0.2123460 total: 3.27s remaining: 73.6ms 978: learn: 0.2121052 total: 3.27s remaining: 70.2ms 979: learn: 0.2118347 total: 3.28s remaining: 66.9ms 980: learn: 0.2114822 total: 3.28s remaining: 63.6ms 981: learn: 0.2112691 total: 3.29s remaining: 60.2ms 982: learn: 0.2110438 total: 3.29s remaining: 56.9ms 983: learn: 0.2108138 total: 3.29s remaining: 53.5ms 984: learn: 0.2106651 total: 3.29s remaining: 50.2ms 985: learn: 0.2104341 total: 3.3s remaining: 46.9ms 986: learn: 0.2104232 total: 3.3s remaining: 43.5ms 987: learn: 0.2103155 total: 3.31s remaining: 40.2ms 988: learn: 0.2100220 total: 3.31s remaining: 36.8ms 989: learn: 0.2098433 total: 3.31s remaining: 33.5ms 990: learn: 0.2094910 total: 3.32s remaining: 30.1ms 991: learn: 0.2094807 total: 3.32s remaining: 26.8ms 992: learn: 0.2093171 total: 3.33s remaining: 23.4ms 993: learn: 0.2090154 total: 3.33s remaining: 20.1ms 994: learn: 0.2088564 total: 3.33s remaining: 16.7ms 995: learn: 0.2085709 total: 3.33s remaining: 13.4ms 996: learn: 0.2083527 total: 3.34s remaining: 10ms 997: learn: 0.2080872 total: 3.34s remaining: 6.7ms 998: learn: 0.2080622 total: 3.35s remaining: 3.35ms 999: learn: 0.2077081 total: 3.35s remaining: 0us
# Display the model comparison results (MSE / MAE / R2 / training time per model)
df_compare
| Model | Mean Squared Error | Mean Absolute Error | R2 Score | Training Time (s) | |
|---|---|---|---|---|---|
| 0 | Ridge Regression | 0.749019 | 0.601173 | 0.981691 | 0.005195 |
| 1 | Linear Regression | 0.767788 | 0.612101 | 0.981232 | 0.013318 |
| 2 | Lasso Regression | 1.101427 | 0.785124 | 0.973077 | 0.092810 |
| 3 | Elastic Net | 1.12101 | 0.749384 | 0.972598 | 0.098538 |
| 4 | Gradient Boosting | 88.690089 | 7.57935 | -1.167937 | 2.154962 |
| 5 | AdaBoost | 91.071163 | 7.715902 | -1.226139 | 0.682553 |
| 6 | Random Forest | 93.276696 | 7.841589 | -1.280051 | 3.758919 |
| 11 | CatBoost | 94.177261 | 7.912422 | -1.302065 | 3.500850 |
| 7 | Decision Tree | 95.967771 | 8.005216 | -1.345832 | 0.068659 |
| 8 | XGBoost | 103.064918 | 8.379896 | -1.519314 | 0.464069 |
| 9 | K-Neighbors Regressor | 330.959365 | 16.750109 | -7.089956 | 0.000567 |
| 10 | SVR | 349.436131 | 17.668843 | -7.541601 | 0.134333 |
We explored a variety of regression models to predict stock prices, including Linear Regression, Ridge Regression, Lasso Regression, Elastic Net, Support Vector Regression (SVR), K-Neighbors Regressor, Decision Tree, Random Forest, Gradient Boosting, AdaBoost, XGBoost, and CatBoost.
Ridge Regression, Linear Regression, Lasso Regression, and Elastic Net models have higher accuracy, achieving R2 scores between 0.97 and 0.98, while the decision-tree and boosting algorithms fail to predict accurately.
In the following sections we will retrain the Ridge Regression, Linear Regression, Lasso Regression, and Elastic Net models, aiming to enhance their accuracy and performance.
# Fit a baseline linear regression on the scaled training features.
# LinearRegression.fit returns the estimator, so fit inline.
lr_model_base = LinearRegression().fit(X_train_scaled, y_train)

# Predict on the scaled test set.
lr_pred_base = lr_model_base.predict(X_test_scaled)

# Collect test-period dates, actuals and baseline predictions in one frame.
# The test period corresponds to observations from 2020 onward.
test_mask = df.date.dt.year >= 2020
prediction_df = pd.DataFrame({'date': df.loc[test_mask, 'date']})
prediction_df['y_test'] = y_test
prediction_df['lr_pred_base'] = lr_pred_base
prediction_df.head()
| date | y_test | lr_pred_base | |
|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 |
# Score the baseline linear model on the test set (MSE / RMSE / MAE / R2)
lr_score_base = evaluate_regression_model(y_test, lr_pred_base)
Mean Squared Error (MSE): 0.768 Root Mean Squared Error (RMSE): 0.876 Mean Absolute Error (MAE): 0.612 R-squared (R2): 0.981
# Display the baseline score dictionary
lr_score_base
{'MSE': 0.7677881743434356,
'RMSE': 0.8762352277462003,
'MAE': 0.612101458244357,
'R2': 0.9812322205481261}
# Visualize prediction accuracy and the predicted-vs-actual series
# for the baseline linear model.
plot_regression_accuracy(y_test, lr_pred_base)
plot_predictions(df, lr_pred_base)

# Rank the 20 most influential features of the baseline model,
# then preview the top 15 rows of the ranking.
top_n = 20
lr_base_feature_importance = plot_feature_importance(lr_model_base, X_train, top_n)
lr_base_feature_importance[:15]
| Feature | Importance | |
|---|---|---|
| 0 | adj_close_5d_avg | 34.506178 |
| 1 | sma_5 | 25.998560 |
| 2 | ema_9 | 21.885541 |
| 3 | adj_close_1d_ago | 9.666451 |
| 4 | adj_close_15d_avg | 8.627124 |
| 5 | close_5d_ago | 8.502948 |
| 6 | close_1d_ago | 7.337071 |
| 7 | adj_close_3d_avg | 7.122290 |
| 8 | low_5d_avg | 5.716135 |
| 9 | low_10d_avg | 5.542512 |
| 10 | open_5d_avg | 5.011481 |
| 11 | adj_close_5d_ago | 4.463160 |
| 12 | adj_close_7d_avg | 4.352698 |
| 13 | adj_close_1w_ago | 4.039827 |
| 14 | adj_close_3d_ago | 3.904271 |
# Reduce the design matrix to the 20 strongest features of the baseline model.
keep_cols20 = lr_base_feature_importance[:20]['Feature'].tolist()
X_train20, X_test20 = X_train[keep_cols20], X_test[keep_cols20]

# Re-fit the scaler on the reduced training set and apply it to the test set.
scaler = StandardScaler()
X_train_scaled20 = scaler.fit_transform(X_train20)
X_test_scaled20 = scaler.transform(X_test20)

# Fit, predict and score the 20-feature linear model.
lr_model20 = LinearRegression()
lr_model20.fit(X_train_scaled20, y_train)
lr_pred20 = lr_model20.predict(X_test_scaled20)
lr_score20 = evaluate_regression_model(y_test, lr_pred20)
Mean Squared Error (MSE): 0.777 Root Mean Squared Error (RMSE): 0.882 Mean Absolute Error (MAE): 0.611 R-squared (R2): 0.981
# Store the 20-feature predictions alongside the baseline and preview
prediction_df['lr_pred20'] = lr_pred20
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | |
|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 |
# Display the score dictionary for the 20-feature linear model
lr_score20
{'MSE': 0.7771456769467653,
'RMSE': 0.8815586633609617,
'MAE': 0.6109586051263984,
'R2': 0.9810034861771783}
# Plot feature importances of the 20-feature linear model
plot_feature_importance(lr_model20,X_train20,20)
| Feature | Importance | |
|---|---|---|
| 0 | adj_close_5d_avg | 38.854308 |
| 1 | sma_5 | 27.484310 |
| 2 | adj_close_1d_ago | 23.775534 |
| 3 | close_1d_ago | 19.321375 |
| 4 | close_5d_ago | 15.098392 |
| 5 | adj_close_5d_ago | 11.704156 |
| 6 | adj_close_15d_avg | 4.158538 |
| 7 | adj_close_1w_ago | 3.070985 |
| 8 | sma_15 | 2.762572 |
| 9 | close_1w_ago | 2.757137 |
| 10 | low_5d_avg | 2.447157 |
| 11 | open_5d_avg | 1.692342 |
| 12 | low_10d_avg | 1.600124 |
| 13 | high_5d_avg | 1.304435 |
| 14 | open_10d_avg | 1.301215 |
| 15 | adj_close_3d_ago | 1.169397 |
| 16 | ema_9 | 0.882495 |
| 17 | adj_close_30d_avg | 0.311786 |
| 18 | adj_close_7d_avg | 0.285091 |
| 19 | adj_close_3d_avg | 0.276478 |
# Narrow the feature set down to the top 15 of the baseline ranking.
keep_cols15 = lr_base_feature_importance[:15]['Feature'].tolist()
X_train15, X_test15 = X_train[keep_cols15], X_test[keep_cols15]

# Fresh scaler fit on the 15-feature training set; transform test with it.
scaler = StandardScaler()
X_train_scaled15 = scaler.fit_transform(X_train15)
X_test_scaled15 = scaler.transform(X_test15)

# Fit, predict and score the 15-feature linear model.
lr_model15 = LinearRegression()
lr_model15.fit(X_train_scaled15, y_train)
lr_pred15 = lr_model15.predict(X_test_scaled15)
lr_score15 = evaluate_regression_model(y_test, lr_pred15)
Mean Squared Error (MSE): 0.764 Root Mean Squared Error (RMSE): 0.874 Mean Absolute Error (MAE): 0.608 R-squared (R2): 0.981
# Store the 15-feature predictions and preview the comparison frame
prediction_df['lr_pred15'] = lr_pred15
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | |
|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 |
# Display the score dictionary for the 15-feature linear model
lr_score15
{'MSE': 0.76431207310404,
'RMSE': 0.874249434145679,
'MAE': 0.6082563126517426,
'R2': 0.9813171902098028}
# Plot feature importances of the 15-feature linear model
plot_feature_importance(lr_model15,X_train15,15)
| Feature | Importance | |
|---|---|---|
| 0 | adj_close_5d_avg | 36.266693 |
| 1 | sma_5 | 26.803574 |
| 2 | adj_close_1d_ago | 22.806743 |
| 3 | close_1d_ago | 18.534424 |
| 4 | close_5d_ago | 14.111041 |
| 5 | adj_close_5d_ago | 10.688623 |
| 6 | adj_close_3d_ago | 1.161695 |
| 7 | low_5d_avg | 0.833075 |
| 8 | ema_9 | 0.494789 |
| 9 | adj_close_15d_avg | 0.421462 |
| 10 | adj_close_7d_avg | 0.379717 |
| 11 | low_10d_avg | 0.299484 |
| 12 | adj_close_1w_ago | 0.296130 |
| 13 | adj_close_3d_avg | 0.259609 |
| 14 | open_5d_avg | 0.252947 |
# Keep only the 10 highest-ranked features from the baseline importance table.
keep_cols10 = lr_base_feature_importance[:10]['Feature'].tolist()
X_train10, X_test10 = X_train[keep_cols10], X_test[keep_cols10]

# Scale the reduced matrices (fit on train only to avoid test leakage).
scaler = StandardScaler()
X_train_scaled10 = scaler.fit_transform(X_train10)
X_test_scaled10 = scaler.transform(X_test10)

# Fit, predict and score the 10-feature linear model.
lr_model10 = LinearRegression()
lr_model10.fit(X_train_scaled10, y_train)
lr_pred10 = lr_model10.predict(X_test_scaled10)
lr_score10 = evaluate_regression_model(y_test, lr_pred10)
Mean Squared Error (MSE): 0.843 Root Mean Squared Error (RMSE): 0.918 Mean Absolute Error (MAE): 0.634 R-squared (R2): 0.979
# Store the 10-feature predictions and preview the comparison frame
prediction_df['lr_pred10'] = lr_pred10
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | lr_pred10 | |
|---|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 | 53.928125 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 | 54.251539 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 | 54.299935 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 | 54.039427 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 | 54.149424 |
# Display the score dictionary for the 10-feature linear model
lr_score10
{'MSE': 0.8431161224362277,
'RMSE': 0.91821354947323,
'MAE': 0.6337464901517136,
'R2': 0.9793909075875864}
# Plot feature importances of the 10-feature linear model
plot_feature_importance(lr_model10,X_train10,10)
| Feature | Importance | |
|---|---|---|
| 0 | adj_close_1d_ago | 31.773423 |
| 1 | adj_close_5d_avg | 30.972946 |
| 2 | close_1d_ago | 25.880350 |
| 3 | sma_5 | 25.688939 |
| 4 | close_5d_ago | 5.203279 |
| 5 | adj_close_3d_avg | 2.188661 |
| 6 | adj_close_15d_avg | 1.072305 |
| 7 | ema_9 | 0.829016 |
| 8 | low_5d_avg | 0.412671 |
| 9 | low_10d_avg | 0.181438 |
# Tune a Ridge regressor: 5-fold CV grid search over the regularization
# strength alpha, scored by negative MSE, then evaluate on the test set.
ridge_model = Ridge()
# Define the hyperparameter grid to search
param_grid = {'alpha': [0.001, 0.01, 0.1, 1, 10, 100]}
# Perform GridSearchCV for hyperparameter tuning
grid_search = GridSearchCV(estimator=ridge_model, param_grid=param_grid, scoring='neg_mean_squared_error', cv=5)
grid_search.fit(X_train_scaled, y_train)
# Get the best model (already refit on the full training set by GridSearchCV)
best_ridge_model = grid_search.best_estimator_
# Make predictions on the test set
ridge_pred_base = best_ridge_model.predict(X_test_scaled)
# Evaluate the best model
mse = mean_squared_error(y_test, ridge_pred_base)
# FIX: mean_squared_error(..., squared=False) was deprecated in scikit-learn 1.4
# and removed in 1.6 — derive RMSE from the MSE directly instead.
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, ridge_pred_base)
r2 = r2_score(y_test, ridge_pred_base)
print("Best Ridge Regression Model:")
print(f"Best alpha: {best_ridge_model.alpha}")
print(f'Root Mean Squared Error (RMSE): {np.round(rmse,3)}')
print(f"Mean Squared Error: {np.round(mse,3)}")
print(f"Mean Absolute Error: {np.round(mae,3)}")
print(f"R2 Score: {np.round(r2,3)}")
# Collect scores in the same dict shape as evaluate_regression_model returns
ridge_score = {
    'MSE': mse,
    'RMSE': rmse,
    'MAE': mae,
    'R2': r2
}
Best Ridge Regression Model: Best alpha: 0.001 Root Mean Squared Error (RMSE): 0.872 Mean Squared Error: 0.76 Mean Absolute Error: 0.609 R2 Score: 0.981
# Rank the top-20 features of the tuned Ridge model and preview the ranking
ridge_base_feature_importance = plot_feature_importance(best_ridge_model,X_train,20)
ridge_base_feature_importance[:20]
| Feature | Importance | |
|---|---|---|
| 0 | adj_close_5d_avg | 29.209670 |
| 1 | sma_5 | 23.327076 |
| 2 | ema_9 | 20.052079 |
| 3 | adj_close_1d_ago | 8.608975 |
| 4 | close_5d_ago | 7.815929 |
| 5 | adj_close_15d_avg | 6.758087 |
| 6 | close_1d_ago | 6.383786 |
| 7 | adj_close_3d_avg | 6.095190 |
| 8 | low_10d_avg | 5.185958 |
| 9 | high_5d_avg | 4.577397 |
| 10 | low_5d_avg | 4.530334 |
| 11 | adj_close_5d_ago | 4.444782 |
| 12 | open_5d_avg | 4.252683 |
| 13 | sma_10 | 4.082776 |
| 14 | adj_close_30d_avg | 3.729390 |
| 15 | macd | 3.581098 |
| 16 | sma_15 | 3.507948 |
| 17 | high_7d_avg | 3.355806 |
| 18 | open_10d_avg | 3.126070 |
| 19 | high_15d_avg | 3.066558 |
# Store the tuned-Ridge predictions and preview the comparison frame
prediction_df['ridge_pred_base'] = ridge_pred_base
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | lr_pred10 | ridge_pred_base | |
|---|---|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 | 53.928125 | 54.233949 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 | 54.251539 | 54.459871 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 | 54.299935 | 54.363214 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 | 54.039427 | 53.879662 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 | 54.149424 | 54.173588 |
# Keep the 20 strongest features from the Ridge importance ranking.
keep_cols20 = ridge_base_feature_importance[:20]['Feature'].tolist()
X_train20, X_test20 = X_train[keep_cols20], X_test[keep_cols20]

# Re-scale the reduced feature matrices (scaler fit on train only).
scaler = StandardScaler()
X_train_scaled20 = scaler.fit_transform(X_train20)
X_test_scaled20 = scaler.transform(X_test20)

# Refit Ridge with the alpha selected by the grid search (0.001),
# then predict and score on the 20-feature test set.
ridge_model20 = Ridge(alpha=0.001).fit(X_train_scaled20, y_train)
ridge_pred20 = ridge_model20.predict(X_test_scaled20)
ridge_score20 = evaluate_regression_model(y_test, ridge_pred20)
Mean Squared Error (MSE): 0.771 Root Mean Squared Error (RMSE): 0.878 Mean Absolute Error (MAE): 0.609 R-squared (R2): 0.981
# Re-rank feature importance within the reduced 20-feature Ridge model
plot_feature_importance(ridge_model20,X_train20,20)
| Feature | Importance | |
|---|---|---|
| 0 | adj_close_5d_avg | 34.865155 |
| 1 | sma_5 | 26.561371 |
| 2 | adj_close_1d_ago | 23.306249 |
| 3 | close_1d_ago | 19.054534 |
| 4 | close_5d_ago | 13.719602 |
| 5 | adj_close_5d_ago | 9.989549 |
| 6 | adj_close_15d_avg | 2.405561 |
| 7 | low_5d_avg | 2.326657 |
| 8 | sma_15 | 1.820856 |
| 9 | low_10d_avg | 1.635199 |
| 10 | open_5d_avg | 1.424137 |
| 11 | adj_close_3d_avg | 1.268768 |
| 12 | sma_10 | 1.218959 |
| 13 | high_5d_avg | 1.016692 |
| 14 | open_10d_avg | 0.924125 |
| 15 | ema_9 | 0.564495 |
| 16 | high_7d_avg | 0.189963 |
| 17 | adj_close_30d_avg | 0.186579 |
| 18 | high_15d_avg | 0.086638 |
| 19 | macd | 0.004499 |
# Record the reduced-feature Ridge predictions for later comparison
prediction_df['ridge_pred20'] = ridge_pred20
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | lr_pred10 | ridge_pred_base | ridge_pred20 | |
|---|---|---|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 | 53.928125 | 54.233949 | 54.297031 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 | 54.251539 | 54.459871 | 54.574923 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 | 54.299935 | 54.363214 | 54.606500 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 | 54.039427 | 53.879662 | 53.954122 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 | 54.149424 | 54.173588 | 54.033765 |
# Tune the Lasso regularization strength with 5-fold CV on the scaled training data
lasso_model = Lasso()
param_grid = {'alpha': [0.001, 0.01, 0.1, 1, 10, 100]}
# Perform GridSearchCV for hyperparameter tuning
grid_search = GridSearchCV(estimator=lasso_model, param_grid=param_grid, scoring='neg_mean_squared_error', cv=5)
grid_search.fit(X_train_scaled, y_train)
# Get the best model
best_lasso_model = grid_search.best_estimator_
# Make predictions on the test set
lasso_pred_base = best_lasso_model.predict(X_test_scaled)
# Evaluate the best model
mse = mean_squared_error(y_test, lasso_pred_base)
# np.sqrt(mse) instead of mean_squared_error(..., squared=False): the
# `squared` keyword is deprecated and removed in scikit-learn >= 1.6
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, lasso_pred_base)
r2 = r2_score(y_test, lasso_pred_base)
print("Best Lasso Regression Model:")
print(f"Best alpha: {best_lasso_model.alpha}")
print(f'Root Mean Squared Error (RMSE): {np.round(rmse,3)}')
print(f"Mean Squared Error: {np.round(mse,3)}")
print(f"Mean Absolute Error: {np.round(mae,3)}")
print(f"R2 Score: {np.round(r2,3)}")
# Keep the metrics for the model-comparison table built later
lasso_score = {
'MSE': mse,
'RMSE': rmse,
'MAE': mae,
'R2': r2
}
Best Lasso Regression Model: Best alpha: 0.001 Root Mean Squared Error (RMSE): 0.967 Mean Squared Error: 0.935 Mean Absolute Error: 0.662 R2 Score: 0.977
# Rank features by the tuned Lasso's coefficient magnitudes
lasso_base_feature_importance = plot_feature_importance(best_lasso_model,X_train,20)
lasso_base_feature_importance[:20]
| Feature | Importance | |
|---|---|---|
| 0 | ema_9 | 4.871884 |
| 1 | macd | 1.751569 |
| 2 | macd_signal | 1.345834 |
| 3 | close_3d_ago | 0.958941 |
| 4 | open_15d_avg | 0.639682 |
| 5 | rsi | 0.562192 |
| 6 | sma_15 | 0.495187 |
| 7 | low_1d_ago | 0.416806 |
| 8 | open_2w_ago | 0.305712 |
| 9 | adj_close_3d_avg | 0.290374 |
| 10 | sma_30 | 0.269198 |
| 11 | high_2w_ago | 0.211690 |
| 12 | low_3w_ago | 0.159687 |
| 13 | open_30d_avg | 0.154355 |
| 14 | volume_5d_avg | 0.096779 |
| 15 | close_2w_ago | 0.090966 |
| 16 | high_30d_avg | 0.082343 |
| 17 | low_30d_avg | 0.076811 |
| 18 | open_3w_ago | 0.065410 |
| 19 | open_5d_ago | 0.064327 |
# Record the base Lasso predictions for later comparison
prediction_df['lasso_pred_base'] = lasso_pred_base
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | lr_pred10 | ridge_pred_base | ridge_pred20 | lasso_pred_base | |
|---|---|---|---|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 | 53.928125 | 54.233949 | 54.297031 | 54.481490 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 | 54.251539 | 54.459871 | 54.574923 | 54.207026 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 | 54.299935 | 54.363214 | 54.606500 | 53.919954 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 | 54.039427 | 53.879662 | 53.954122 | 53.833212 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 | 54.149424 | 54.173588 | 54.033765 | 53.915187 |
# Keep only the 20 features the base Lasso ranked highest
keep_cols20 = lasso_base_feature_importance[:20]['Feature'].tolist()
X_train20, X_test20 = X_train[keep_cols20], X_test[keep_cols20]
# Standardize with training-split statistics only
scaler = StandardScaler().fit(X_train20)
X_train_scaled20 = scaler.transform(X_train20)
X_test_scaled20 = scaler.transform(X_test20)
# Refit Lasso at the tuned alpha and score the reduced model
lasso_model20 = Lasso(alpha=0.001).fit(X_train_scaled20, y_train)
lasso_pred20 = lasso_model20.predict(X_test_scaled20)
lasso_score20 = evaluate_regression_model(y_test, lasso_pred20)
Mean Squared Error (MSE): 0.95 Root Mean Squared Error (RMSE): 0.975 Mean Absolute Error (MAE): 0.667 R-squared (R2): 0.977
# Re-rank feature importance within the reduced 20-feature Lasso model
plot_feature_importance(lasso_model20,X_train20,20)
| Feature | Importance | |
|---|---|---|
| 0 | ema_9 | 4.359165 |
| 1 | macd | 1.683244 |
| 2 | open_15d_avg | 1.543022 |
| 3 | macd_signal | 1.304271 |
| 4 | close_3d_ago | 0.919101 |
| 5 | open_2w_ago | 0.588218 |
| 6 | rsi | 0.538464 |
| 7 | sma_30 | 0.471702 |
| 8 | low_1d_ago | 0.460356 |
| 9 | adj_close_3d_avg | 0.362687 |
| 10 | low_3w_ago | 0.172657 |
| 11 | high_2w_ago | 0.107719 |
| 12 | volume_5d_avg | 0.023766 |
| 13 | close_2w_ago | 0.012261 |
| 14 | open_3w_ago | 0.008425 |
| 15 | sma_15 | 0.000000 |
| 16 | open_30d_avg | 0.000000 |
| 17 | high_30d_avg | 0.000000 |
| 18 | low_30d_avg | 0.000000 |
| 19 | open_5d_ago | 0.000000 |
# Record the reduced-feature Lasso predictions for later comparison
prediction_df['lasso_pred20'] = lasso_pred20
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | lr_pred10 | ridge_pred_base | ridge_pred20 | lasso_pred_base | lasso_pred20 | |
|---|---|---|---|---|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 | 53.928125 | 54.233949 | 54.297031 | 54.481490 | 54.540274 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 | 54.251539 | 54.459871 | 54.574923 | 54.207026 | 54.313239 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 | 54.299935 | 54.363214 | 54.606500 | 53.919954 | 53.991448 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 | 54.039427 | 53.879662 | 53.954122 | 53.833212 | 53.885031 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 | 54.149424 | 54.173588 | 54.033765 | 53.915187 | 53.943523 |
# Tune ElasticNet's alpha and L1/L2 mix with 5-fold CV on the scaled training data
elastic_net_model = ElasticNet()
# Define the hyperparameter grid to search
param_grid = {
'alpha': [0.001, 0.01, 0.1, 1, 10, 100],
'l1_ratio': [0.1, 0.3, 0.5, 0.7, 0.9]
}
# Perform GridSearchCV for hyperparameter tuning
grid_search = GridSearchCV(estimator=elastic_net_model, param_grid=param_grid, scoring='neg_mean_squared_error', cv=5)
grid_search.fit(X_train_scaled, y_train)
# Get the best model
best_elastic_net_model = grid_search.best_estimator_
# Make predictions on the test set
elastic_pred_base = best_elastic_net_model.predict(X_test_scaled)
# Evaluate the best model
mse = mean_squared_error(y_test, elastic_pred_base)
# np.sqrt(mse) instead of mean_squared_error(..., squared=False): the
# `squared` keyword is deprecated and removed in scikit-learn >= 1.6
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, elastic_pred_base)
r2 = r2_score(y_test, elastic_pred_base)
print("Best Elastic Net Model:")
print(f"Best alpha: {best_elastic_net_model.alpha}")
print(f"Best l1_ratio: {best_elastic_net_model.l1_ratio}")
print(f'Root Mean Squared Error (RMSE): {np.round(rmse,3)}')
print(f"Mean Squared Error: {np.round(mse,3)}")
print(f"Mean Absolute Error: {np.round(mae,3)}")
print(f"R2 Score: {np.round(r2,3)}")
# Keep the metrics for the model-comparison table built later
elastic_score = {
'MSE': mse,
'RMSE': rmse,
'MAE': mae,
'R2': r2
}
Best Elastic Net Model: Best alpha: 0.001 Best l1_ratio: 0.9 Root Mean Squared Error (RMSE): 0.964 Mean Squared Error: 0.929 Mean Absolute Error: 0.66 R2 Score: 0.977
# Rank features by the tuned ElasticNet's coefficient magnitudes
elastic_base_feature_importance = plot_feature_importance(best_elastic_net_model,X_train,20)
elastic_base_feature_importance[:20]
| Feature | Importance | |
|---|---|---|
| 0 | ema_9 | 4.372681 |
| 1 | macd | 1.808266 |
| 2 | macd_signal | 1.391821 |
| 3 | close_3d_ago | 0.961167 |
| 4 | sma_15 | 0.856338 |
| 5 | open_15d_avg | 0.654922 |
| 6 | rsi | 0.556937 |
| 7 | low_1d_ago | 0.425022 |
| 8 | sma_30 | 0.335015 |
| 9 | open_2w_ago | 0.294117 |
| 10 | adj_close_3d_avg | 0.290065 |
| 11 | high_2w_ago | 0.224628 |
| 12 | open_30d_avg | 0.213522 |
| 13 | low_3w_ago | 0.172645 |
| 14 | volume_5d_avg | 0.103550 |
| 15 | open_5d_ago | 0.087747 |
| 16 | high_30d_avg | 0.086648 |
| 17 | low_30d_avg | 0.083314 |
| 18 | open_3w_ago | 0.076077 |
| 19 | close_2w_ago | 0.072879 |
# Record the base ElasticNet predictions for later comparison
prediction_df['elastic_pred_base'] = elastic_pred_base
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | lr_pred10 | ridge_pred_base | ridge_pred20 | lasso_pred_base | lasso_pred20 | elastic_pred_base | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 | 53.928125 | 54.233949 | 54.297031 | 54.481490 | 54.540274 | 54.475621 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 | 54.251539 | 54.459871 | 54.574923 | 54.207026 | 54.313239 | 54.203798 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 | 54.299935 | 54.363214 | 54.606500 | 53.919954 | 53.991448 | 53.919188 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 | 54.039427 | 53.879662 | 53.954122 | 53.833212 | 53.885031 | 53.834623 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 | 54.149424 | 54.173588 | 54.033765 | 53.915187 | 53.943523 | 53.922324 |
# Keep only the 20 features the base ElasticNet ranked highest
keep_cols20 = elastic_base_feature_importance[:20]['Feature'].tolist()
X_train20, X_test20 = X_train[keep_cols20], X_test[keep_cols20]
# Standardize with training-split statistics only
scaler = StandardScaler().fit(X_train20)
X_train_scaled20 = scaler.transform(X_train20)
X_test_scaled20 = scaler.transform(X_test20)
# Refit ElasticNet at the tuned hyperparameters and score the reduced model
elastic_model20 = ElasticNet(alpha=0.001, l1_ratio=0.9).fit(X_train_scaled20, y_train)
elastic_pred20 = elastic_model20.predict(X_test_scaled20)
elastic_score20 = evaluate_regression_model(y_test, elastic_pred20)
Mean Squared Error (MSE): 0.96 Root Mean Squared Error (RMSE): 0.98 Mean Absolute Error (MAE): 0.667 R-squared (R2): 0.977
# Re-rank feature importance within the reduced 20-feature ElasticNet model
plot_feature_importance(elastic_model20,X_train20,20)
| Feature | Importance | |
|---|---|---|
| 0 | ema_9 | 3.674293 |
| 1 | macd | 1.750184 |
| 2 | sma_15 | 1.705084 |
| 3 | macd_signal | 1.361563 |
| 4 | close_3d_ago | 0.845265 |
| 5 | sma_30 | 0.726128 |
| 6 | rsi | 0.539805 |
| 7 | low_1d_ago | 0.501059 |
| 8 | open_2w_ago | 0.363514 |
| 9 | open_15d_avg | 0.326419 |
| 10 | adj_close_3d_avg | 0.309635 |
| 11 | high_2w_ago | 0.291731 |
| 12 | low_3w_ago | 0.195555 |
| 13 | volume_5d_avg | 0.026565 |
| 14 | low_30d_avg | 0.000549 |
| 15 | open_30d_avg | 0.000468 |
| 16 | open_5d_ago | 0.000000 |
| 17 | high_30d_avg | 0.000000 |
| 18 | open_3w_ago | 0.000000 |
| 19 | close_2w_ago | 0.000000 |
# Record the reduced-feature ElasticNet predictions for later comparison
prediction_df['elastic_pred20'] = elastic_pred20
prediction_df.head()
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | lr_pred10 | ridge_pred_base | ridge_pred20 | lasso_pred_base | lasso_pred20 | elastic_pred_base | elastic_pred20 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 | 53.928125 | 54.233949 | 54.297031 | 54.481490 | 54.540274 | 54.475621 | 54.512223 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 | 54.251539 | 54.459871 | 54.574923 | 54.207026 | 54.313239 | 54.203798 | 54.293587 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 | 54.299935 | 54.363214 | 54.606500 | 53.919954 | 53.991448 | 53.919188 | 53.995023 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 | 54.039427 | 53.879662 | 53.954122 | 53.833212 | 53.885031 | 53.834623 | 53.869575 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 | 54.149424 | 54.173588 | 54.033765 | 53.915187 | 53.943523 | 53.922324 | 53.944519 |
def _score_to_df(score, model_name):
    """Turn a {metric: value} dict into a one-row DataFrame tagged with the model name."""
    row = pd.DataFrame([list(score.values())], columns=list(score.keys()))
    row['Model'] = model_name
    return row

# Build one metrics row per trained model variant (replaces ten copies of the
# same four-line DataFrame-construction idiom)
ela_df = _score_to_df(elastic_score, 'Elastic_Net with All Features')
ela_20_df = _score_to_df(elastic_score20, 'Elastic_Net with Top 20 Features')
lasso_df = _score_to_df(lasso_score, 'Lasso with All Features')
lasso_20_df = _score_to_df(lasso_score20, 'Lasso with Top 20 Features')
ridge_df = _score_to_df(ridge_score, 'Ridge with All Features')
ridge_20_df = _score_to_df(ridge_score20, 'Ridge with Top 20 Features')
lr_base_df = _score_to_df(lr_score_base, 'Linear Reg. with All Features')
lr_20_df = _score_to_df(lr_score20, 'Linear Reg. with Top 20 Features')
lr_15_df = _score_to_df(lr_score15, 'Linear Reg. with Top 15 Features')
lr_10_df = _score_to_df(lr_score10, 'Linear Reg. with Top 10 Features')
# Rank all variants by R2, best first
df_compare = pd.concat([ela_df,lasso_df,ridge_df,ela_20_df,lasso_20_df,ridge_20_df,
                        lr_base_df,lr_20_df,lr_15_df,lr_10_df]).sort_values(by=['R2'],ascending=False).reset_index(drop=True)
df_compare
| MSE | RMSE | MAE | R2 | Model | |
|---|---|---|---|---|---|
| 0 | 0.76015 | 0.871866 | 0.60886 | 0.981419 | Ridge with All Features |
| 1 | 0.764312 | 0.874249 | 0.608256 | 0.981317 | Linear Reg. with Top 15 Features |
| 2 | 0.767788 | 0.876235 | 0.612101 | 0.981232 | Linear Reg. with All Features |
| 3 | 0.771305 | 0.87824 | 0.60942 | 0.981146 | Ridge with Top 20 Features |
| 4 | 0.777146 | 0.881559 | 0.610959 | 0.981003 | Linear Reg. with Top 20 Features |
| 5 | 0.843116 | 0.918214 | 0.633746 | 0.979391 | Linear Reg. with Top 10 Features |
| 6 | 0.928855 | 0.963771 | 0.659537 | 0.977295 | Elastic_Net with All Features |
| 7 | 0.935026 | 0.966967 | 0.661521 | 0.977144 | Lasso with All Features |
| 8 | 0.949732 | 0.974542 | 0.666845 | 0.976785 | Lasso with Top 20 Features |
| 9 | 0.959534 | 0.979558 | 0.667449 | 0.976545 | Elastic_Net with Top 20 Features |
After retraining the models with different alpha values and input feature sets, the Ridge regression model with alpha 0.001 and all features performed best.
Mean Squared Error (MSE) 0.76015: MSE measures the average squared difference between predicted and actual values. In this case, the MSE of 0.76015 is relatively low, indicating that, on average, the squared errors between predicted and actual values are small. Lower MSE values suggest better accuracy.
Root Mean Squared Error (RMSE) 0.871866: RMSE is the square root of the MSE and provides a measure of the average magnitude of the errors. A lower RMSE (0.871866) signifies that, on average, the model's predictions are close to the actual values. It is in the same unit as the target variable.
Mean Absolute Error (MAE) 0.60886: MAE measures the average absolute difference between predicted and actual values. With an MAE of 0.60886, the model's predictions, on average, deviate by approximately 0.60886 units from the actual values. Lower MAE values indicate better accuracy.
R-squared (R2) 0.981419: R2 represents the proportion of variance in the target variable that is predictable from the independent variables. An R2 value of 0.981419 is exceptionally high, indicating that the model explains about 98.14% of the variance in the closing stock prices. A higher R2 value suggests a better accuracy.
In summary, the provided accuracy scores collectively suggest that the model performs exceptionally well. The low MSE, RMSE, MAE and high R2 score indicate that the model's predictions are close to the actual values.
# Full table of actual values and every model's predictions over the test period
prediction_df
| date | y_test | lr_pred_base | lr_pred20 | lr_pred15 | lr_pred10 | ridge_pred_base | ridge_pred20 | lasso_pred_base | lasso_pred20 | elastic_pred_base | elastic_pred20 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1729 | 2020-01-02 | 54.240002 | 54.234807 | 54.271747 | 54.203303 | 53.928125 | 54.233949 | 54.297031 | 54.481490 | 54.540274 | 54.475621 | 54.512223 |
| 1730 | 2020-01-03 | 54.150002 | 54.462575 | 54.607093 | 54.564688 | 54.251539 | 54.459871 | 54.574923 | 54.207026 | 54.313239 | 54.203798 | 54.293587 |
| 1731 | 2020-01-06 | 53.919998 | 54.392052 | 54.587532 | 54.570543 | 54.299935 | 54.363214 | 54.606500 | 53.919954 | 53.991448 | 53.919188 | 53.995023 |
| 1732 | 2020-01-07 | 54.049999 | 53.866825 | 53.925943 | 53.969683 | 54.039427 | 53.879662 | 53.954122 | 53.833212 | 53.885031 | 53.834623 | 53.869575 |
| 1733 | 2020-01-08 | 54.189999 | 54.141524 | 53.925500 | 53.970832 | 54.149424 | 54.173588 | 54.033765 | 53.915187 | 53.943523 | 53.922324 | 53.944519 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 2694 | 2023-11-01 | 67.970001 | 67.284685 | 66.932767 | 67.028746 | 67.004810 | 67.325547 | 66.883598 | 66.893549 | 66.983023 | 66.898251 | 66.962332 |
| 2695 | 2023-11-02 | 68.820000 | 67.961443 | 67.903284 | 67.963962 | 67.941089 | 67.983409 | 67.964641 | 67.421988 | 67.462660 | 67.435336 | 67.464215 |
| 2696 | 2023-11-03 | 68.239998 | 68.862446 | 69.001802 | 69.034374 | 68.480353 | 68.837447 | 68.941355 | 68.232541 | 68.282736 | 68.249280 | 68.273542 |
| 2697 | 2023-11-06 | 68.489998 | 68.073511 | 68.184944 | 68.248638 | 67.791667 | 68.100132 | 68.219073 | 68.189779 | 68.252203 | 68.190293 | 68.232252 |
| 2698 | 2023-11-07 | 69.019997 | 68.266807 | 68.945917 | 69.051137 | 68.728268 | 68.332505 | 69.018382 | 68.466769 | 68.568573 | 68.473424 | 68.545360 |
970 rows × 12 columns
# Overlay every model's predictions against the actual close price
plt.figure(figsize=(20, 10))
series_to_plot = ['y_test', 'lr_pred_base', 'lr_pred20', 'lr_pred15', 'lr_pred10',
                  'ridge_pred_base', 'ridge_pred20', 'lasso_pred_base', 'lasso_pred20',
                  'elastic_pred_base', 'elastic_pred20']
for col in series_to_plot:
    sns.lineplot(x=prediction_df.date, y=prediction_df[col], label=col)
plt.legend(prop={'size': 14, 'weight': 'bold'})
plt.title('Model Prediction Comparison', fontsize=16)
plt.ylabel('Prediction', fontsize=14)
plt.xlabel('Date', fontsize=14)
plt.show()
# Actual vs. the four linear-regression variants
plt.figure(figsize=(20, 10))
for col in ['y_test', 'lr_pred_base', 'lr_pred20', 'lr_pred15', 'lr_pred10']:
    sns.lineplot(x=prediction_df.date, y=prediction_df[col], label=col)
plt.legend(prop={'size': 14, 'weight': 'bold'})
plt.title('Model Prediction Comparison', fontsize=16)
plt.ylabel('Prediction', fontsize=14)
plt.xlabel('Date', fontsize=14)
plt.show()
# Actual vs. the two Ridge variants
plt.figure(figsize=(20, 10))
for col in ['y_test', 'ridge_pred_base', 'ridge_pred20']:
    sns.lineplot(x=prediction_df.date, y=prediction_df[col], label=col)
plt.legend(prop={'size': 14, 'weight': 'bold'})
plt.title('Model Prediction Comparison', fontsize=16)
plt.ylabel('Prediction', fontsize=14)
plt.xlabel('Date', fontsize=14)
plt.show()
# Actual vs. the two Lasso variants
plt.figure(figsize=(20, 10))
for col in ['y_test', 'lasso_pred_base', 'lasso_pred20']:
    sns.lineplot(x=prediction_df.date, y=prediction_df[col], label=col)
plt.legend(prop={'size': 14, 'weight': 'bold'})
plt.title('Model Prediction Comparison', fontsize=16)
plt.ylabel('Prediction', fontsize=14)
plt.xlabel('Date', fontsize=14)
plt.show()
# Actual vs. the two ElasticNet variants
plt.figure(figsize=(20, 10))
for col in ['y_test', 'elastic_pred_base', 'elastic_pred20']:
    sns.lineplot(x=prediction_df.date, y=prediction_df[col], label=col)
plt.legend(prop={'size': 14, 'weight': 'bold'})
plt.title('Model Prediction Comparison', fontsize=16)
plt.ylabel('Prediction', fontsize=14)
plt.xlabel('Date', fontsize=14)
plt.show()
# target column is next day's close price
y_train = train_df['close_1d_next'].copy()
# `columns=` keyword: the positional `axis` argument to drop was removed in pandas 2.0
X_train = train_df.drop(columns=['close_1d_next'])
# target column is next day's close price
y_test = test_df['close_1d_next'].copy()
X_test = test_df.drop(columns=['close_1d_next'])
def train_ridge_regression(X_train,X_test,y_train,y_test):
    """Standardize the features, fit Ridge(alpha=0.001), and score it on the test set.

    Returns:
        tuple: (fitted Ridge model, test-set predictions,
                metrics dict from evaluate_regression_model2)
    """
    # (The original built a second, unused Ridge instance here; removed.)
    # Fit the scaler on the training split only to avoid test-set leakage
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    # Train model
    ridge_model = Ridge(alpha=0.001)
    ridge_model.fit(X_train_scaled, y_train)
    # Make predictions on the scaled test set
    ridge_pred = ridge_model.predict(X_test_scaled)
    ridge_score = evaluate_regression_model2(y_test, ridge_pred)
    return ridge_model,ridge_pred,ridge_score
# Refit Ridge on the full feature set via the helper and show the metrics
ridge_model, ridge_pred, ridge_score = train_ridge_regression(X_train,X_test,y_train,y_test)
ridge_score
{'MSE': 0.7601496984667533,
'RMSE': 0.8718656424396785,
'MAE': 0.6088603450778082,
'R2': 0.981418935107418}
# First 15 predictions as a quick sanity check
ridge_pred[:15]
array([54.23394893, 54.45987053, 54.36321399, 53.87966228, 54.1735875 ,
54.09581315, 54.16195972, 54.70517195, 54.39406442, 54.76593117,
55.13141383, 55.07110404, 55.40303391, 55.39633407, 55.50978756])
# Visual diagnostics: residual/scatter accuracy plots and predicted-vs-actual series
plot_regression_accuracy(y_test, ridge_pred)
plot_predictions(df,ridge_pred)
The residual, scatter, and time series line charts above show that the predicted values track the actual values closely. These visualizations confirm that the model makes accurate predictions and captures the structure of the data reliably.
def preprocess_data(df):
    """Add technical indicators, lagged OHLCV values, rolling averages, and the
    prediction target `close_1d_next` to a single-symbol daily price frame.

    Expects lower-case columns: open, high, low, close, 'adj close', volume.
    Mutates and returns the same DataFrame.
    """
    # Trend indicators, shifted one day so each row only sees past information.
    # NOTE(review): ewm(9) sets com=9 (not span=9) — kept as-is to preserve behavior.
    df['ema_9'] = df['close'].ewm(9).mean().shift()
    for window in (5, 10, 15, 30):
        df[f'sma_{window}'] = df['close'].rolling(window).mean().shift()
    # Momentum / money-flow indicators (project helpers)
    df['rsi'] = rsi(df)
    df['mfi'] = mfi(df, 14)
    # MACD: fast EMA minus slow EMA, plus its 9-period signal line
    ema_fast = df['close'].ewm(span=12, min_periods=12).mean()
    ema_slow = df['close'].ewm(span=26, min_periods=26).mean()
    df['macd'] = ema_fast - ema_slow
    df['macd_signal'] = df.macd.ewm(span=9, min_periods=9).mean()
    # Prediction target: the next trading day's close
    df['close_1d_next'] = df['close'].shift(-1)
    # Lagged values: 1/3/5 days and 1-4 weeks back, for each price/volume column
    lags = (('1d', 1), ('3d', 3), ('5d', 5), ('1w', 7),
            ('2w', 14), ('3w', 21), ('4w', 28))
    lag_sources = (('close', 'close'), ('adj close', 'adj_close'),
                   ('open', 'open'), ('high', 'high'),
                   ('low', 'low'), ('volume', 'volume'))
    for src, prefix in lag_sources:
        for label, periods in lags:
            df[f'{prefix}_{label}_ago'] = df[src].shift(periods)
    # Rolling means over several window sizes (unshifted, as in the original)
    avg_sources = (('open', 'open'), ('high', 'high'), ('low', 'low'),
                   ('volume', 'volume'), ('adj close', 'adj_close'))
    for src, prefix in avg_sources:
        for window in (3, 5, 7, 10, 15, 30):
            df[f'{prefix}_{window}d_avg'] = df[src].rolling(window=window).mean()
    return df
# Load the daily bars for all symbols and normalize column names
df_all = pd.read_parquet(out_loc+"stock_1d.parquet")
df_all.columns = df_all.columns.str.lower()
### keep stocks in data with min year 2013, max year 2023
# One named-aggregation groupby replaces the original three groupbys plus two merges
stock_cnt = (df_all.groupby('symbol')['date']
             .agg(min_date='min', max_date='max', days_cnt='count')
             .reset_index())
stock_cnt['min_year'] = stock_cnt['min_date'].dt.year
stock_cnt['max_year'] = stock_cnt['max_date'].dt.year
# Keep only symbols with a full 2013-2023 history and enough trading days
keep_stocks = stock_cnt[(stock_cnt['min_year']==2013)&(stock_cnt['max_year']==2023)&(stock_cnt['days_cnt']>=2500)]['symbol'].unique().tolist()
stock_cnt.head()
| symbol | min_date | max_date | days_cnt | min_year | max_year | |
|---|---|---|---|---|---|---|
| 0 | A | 2013-01-02 | 2023-11-08 | 2733 | 2013 | 2023 |
| 1 | AAL | 2013-01-02 | 2023-11-08 | 2733 | 2013 | 2023 |
| 2 | AAPL | 2013-01-02 | 2023-11-08 | 2733 | 2013 | 2023 |
| 3 | ABBV | 2013-01-02 | 2023-11-08 | 2733 | 2013 | 2023 |
| 4 | ABNB | 2020-12-10 | 2023-11-08 | 733 | 2020 | 2023 |
# Restrict to 2023 rows for the retained symbols
df_2023 = df_all[(df_all.date.dt.year==2023) & (df_all.symbol.isin(keep_stocks))]
# volume vs stocks: total 2023 volume per symbol, highest first
volume_2023 = (df_2023.groupby(['symbol', 'security', 'gics sector'])['volume']
               .sum()
               .reset_index()
               .sort_values(by='volume', ascending=False)
               .reset_index(drop=True))
volume_2023.head()
| symbol | security | gics sector | volume | |
|---|---|---|---|---|
| 0 | TSLA | Tesla, Inc. | Consumer Discretionary | 3.009291e+10 |
| 1 | AMD | AMD | Information Technology | 1.342035e+10 |
| 2 | AMZN | Amazon | Consumer Discretionary | 1.305160e+10 |
| 3 | AAPL | Apple Inc. | Information Technology | 1.303964e+10 |
| 4 | F | Ford Motor Company | Consumer Discretionary | 1.278319e+10 |
# volume vs sectors: total 2023 volume per GICS sector, highest first
sector_2023 = (df_2023.groupby(['gics sector'])['volume']
               .sum()
               .reset_index()
               .sort_values(by='volume', ascending=False)
               .reset_index(drop=True))
sector_2023
| gics sector | volume | |
|---|---|---|
| 0 | Consumer Discretionary | 9.171407e+10 |
| 1 | Information Technology | 8.888840e+10 |
| 2 | Financials | 6.728113e+10 |
| 3 | Communication Services | 5.267892e+10 |
| 4 | Health Care | 3.755560e+10 |
| 5 | Industrials | 3.672492e+10 |
| 6 | Energy | 3.245171e+10 |
| 7 | Consumer Staples | 2.824873e+10 |
| 8 | Utilities | 2.214882e+10 |
| 9 | Materials | 1.432867e+10 |
| 10 | Real Estate | 1.318748e+10 |
# filter top 5 sectors with highest volume in 2023
sector_list = sector_2023[:5]['gics sector'].tolist()
num_stocks = 5
# For each top sector, take its `num_stocks` highest-volume symbols and
# flatten the result in one pass (replaces the append-then-flatten loop)
stock_list = [sym
              for sec in sector_list
              for sym in volume_2023[volume_2023['gics sector'] == sec]['symbol'][:num_stocks].tolist()]
len(stock_list)
25
# Keep only the 25 selected high-volume symbols
df_stocks = df_all[df_all['symbol'].isin(stock_list)].reset_index(drop=True)
df_stocks.head()
| date | open | high | low | close | adj close | volume | symbol | security | gics sector | gics sub-industry | headquarters location | date added | cik | founded | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2013-01-02 | 18.003504 | 18.193193 | 17.931683 | 18.099348 | 18.099348 | 101550348.0 | GOOGL | Alphabet Inc. (Class A) | Communication Services | Interactive Media & Services | Mountain View, California | 2014-04-03 | 1652044 | 1998 |
| 1 | 2013-01-03 | 18.141392 | 18.316566 | 18.036036 | 18.109859 | 18.109859 | 92635272.0 | GOOGL | Alphabet Inc. (Class A) | Communication Services | Interactive Media & Services | Mountain View, California | 2014-04-03 | 1652044 | 1998 |
| 2 | 2013-01-04 | 18.251753 | 18.555305 | 18.210211 | 18.467718 | 18.467718 | 110429460.0 | GOOGL | Alphabet Inc. (Class A) | Communication Services | Interactive Media & Services | Mountain View, California | 2014-04-03 | 1652044 | 1998 |
| 3 | 2013-01-07 | 18.404655 | 18.503002 | 18.282784 | 18.387136 | 18.387136 | 66161772.0 | GOOGL | Alphabet Inc. (Class A) | Communication Services | Interactive Media & Services | Mountain View, California | 2014-04-03 | 1652044 | 1998 |
| 4 | 2013-01-08 | 18.406906 | 18.425926 | 18.128880 | 18.350851 | 18.350851 | 66976956.0 | GOOGL | Alphabet Inc. (Class A) | Communication Services | Interactive Media & Services | Mountain View, California | 2014-04-03 | 1652044 | 1998 |
# Train and score one Ridge model per selected stock, collecting the
# per-stock test metrics (MSE/RMSE/MAE/R2) into a single comparison table.
stock_compare = []
# Columns that are raw data / metadata rather than engineered features;
# hoisted out of the loop since the list never changes between stocks.
drop_cols1 = ['date','open','high','low','close','adj close','volume','symbol','security',
              'gics sector','gics sub-industry','headquarters location','date added','cik','founded']
for stock in stock_list:
    stock_data = df_stocks[df_stocks['symbol'] == stock]
    stock_data = preprocess_data(stock_data)
    stock_data = stock_data.dropna().reset_index(drop=True)
    # Chronological split: train on pre-2020 data, test on 2020 onward
    train_df_temp = stock_data[stock_data.date.dt.year < 2020]
    test_df_temp = stock_data[stock_data.date.dt.year >= 2020]
    # FIX: the positional `axis` argument to DataFrame.drop (e.g. `.drop(cols, 1)`)
    # was deprecated in pandas 1.0 and removed in pandas 2.0 — use `columns=`.
    train_df_temp = train_df_temp.drop(columns=drop_cols1)
    test_df_temp = test_df_temp.drop(columns=drop_cols1)
    # target column is next day's close price
    y_train_temp = train_df_temp['close_1d_next'].copy()
    X_train_temp = train_df_temp.drop(columns=['close_1d_next'])
    y_test_temp = test_df_temp['close_1d_next'].copy()
    X_test_temp = test_df_temp.drop(columns=['close_1d_next'])
    temp_model, temp_pred, temp_score = train_ridge_regression(
        X_train_temp, X_test_temp, y_train_temp, y_test_temp)
    # Build the one-row metrics frame directly from the score dict instead of
    # the keys/values transpose dance — same columns, numeric dtypes.
    score_df = pd.DataFrame([temp_score])
    score_df['symbol'] = stock
    stock_compare.append(score_df)
# Rank all stocks by test-set R2, best first
compare_df = pd.concat(stock_compare).sort_values(by='R2', ascending=False).reset_index(drop=True)
compare_df
| MSE | RMSE | MAE | R2 | symbol | |
|---|---|---|---|---|---|
| 0 | 56.919689 | 7.544514 | 5.112978 | 0.995061 | NVDA |
| 1 | 0.559182 | 0.747785 | 0.529467 | 0.993251 | VZ |
| 2 | 7.770865 | 2.787627 | 2.107652 | 0.992581 | AAPL |
| 3 | 5.039335 | 2.244846 | 1.670374 | 0.992108 | GOOG |
| 4 | 4.98575 | 2.232879 | 1.654225 | 0.992002 | GOOGL |
| 5 | 46.381058 | 6.810364 | 4.690443 | 0.990596 | META |
| 6 | 2.057386 | 1.434359 | 1.020137 | 0.990386 | CVS |
| 7 | 0.147953 | 0.384647 | 0.271314 | 0.990362 | F |
| 8 | 1.27562 | 1.129434 | 0.839937 | 0.98989 | GM |
| 9 | 28.055837 | 5.296776 | 3.992424 | 0.989314 | MSFT |
| 10 | 84.835254 | 9.210606 | 6.497099 | 0.988764 | TSLA |
| 11 | 0.575533 | 0.758639 | 0.55086 | 0.988709 | PFE |
| 12 | 0.581909 | 0.76283 | 0.577919 | 0.988411 | BAC |
| 13 | 0.26222 | 0.512075 | 0.377068 | 0.988041 | KEY |
| 14 | 1.80572 | 1.343771 | 0.958788 | 0.987257 | INTC |
| 15 | 0.991949 | 0.995966 | 0.734726 | 0.987031 | WFC |
| 16 | 11.147927 | 3.338851 | 2.469883 | 0.985621 | AMZN |
| 17 | 0.174755 | 0.418037 | 0.281833 | 0.983153 | T |
| 18 | 10.145635 | 3.185221 | 2.328928 | 0.982675 | AMD |
| 19 | 2.409729 | 1.55233 | 1.141577 | 0.981822 | C |
| 20 | 0.124353 | 0.352637 | 0.259468 | 0.979688 | HBAN |
| 21 | 0.870412 | 0.932958 | 0.684029 | 0.978487 | BMY |
| 22 | 2.087087 | 1.444675 | 1.154954 | 0.96863 | CCL |
| 23 | 4.398395 | 2.097235 | 1.52279 | 0.967914 | JNJ |
| 24 | 4.070663 | 2.017589 | 1.528408 | 0.589411 | VTRS |
The final phase of the project involved applying the developed model to real-world scenarios. By identifying the top 5 GICS sectors with the highest trading volume in 2023, we ensured that our predictions were grounded in current market dynamics. The subsequent selection of the 5 highest-volume stocks within each sector added a layer of practicality to our findings.
The model's strong performance — a test-set R² above 0.96 for 24 of the 25 stocks, led by NVDA, VZ, AAPL, GOOG, and GOOGL — demonstrated its robustness across diverse market conditions. By contrast, the markedly lower R² of 0.59 for VTRS highlights an opportunity for further investigation into the factors contributing to its underperformance.